# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# The SupportedOperators class which is a collection of all supported operators and parameter checks.
import numpy as np

from .data_type import BaseType
from .data_type import DataType
from .numeric_util import is_integer
from .operation import get_slice_offsets
from .operation import Op


# Custom decorator function to allow formatting docstrings containing "{}"
def docstring_format_args(args):
    def docstring(func):
        func.__doc__ = func.__doc__.format(*args)
        return func

    return docstring

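# Illustrative sketch (not used by the module): the decorator above fills "{}"
# placeholders in a docstring at definition time, e.g.
#
#     @docstring_format_args([(1, 65535)])
#     def check(op):
#         "Dimensions must be in the range {}"
#
#     assert check.__doc__ == "Dimensions must be in the range (1, 65535)"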

def warn_cpu(op, msg):
    print("Warning: {} {}, placing on CPU".format(op.type, msg))


class SupportedOperators:
    # Categorised lists of supported operators
    npu_pre_ops = set((Op.SplitSliceRead,))
    convolution_ops = set((Op.Conv2DBias, Op.Conv2D, Op.QuantizedConv2D,))
    depthwise_convolution_ops = set((Op.DepthwiseConv2DBias,))
    transpose_convolution_ops = set((Op.Conv2DBackpropInput,))
    max_pooling_ops = Op.op_set(Op.is_maxpool_op)
    avg_pooling_ops = Op.op_set(Op.is_avgpool_op)
    pooling_ops = set((Op.ReduceSum,)) | max_pooling_ops | avg_pooling_ops
    resizing_ops = set((Op.ResizeBilinear,))
    fc_vector_products = set((Op.QuantizedMatMul, Op.MatMul, Op.FullyConnected,))
    mac_main_ops = (
        # RNN/LSTM/GRU
        set((Op.BlockLSTM,))
        # convolutions
        | convolution_ops
        # depth-wise convolutions
        | depthwise_convolution_ops
        # transpose convolutions
        | transpose_convolution_ops
        # pooling
        | pooling_ops
        # resizing/upscaling
        | resizing_ops
        # FC layers
        | fc_vector_products
    )
    unary_elem_wise_main_ops = Op.op_set(Op.is_unary_elementwise_op)
    binary_elem_wise_min_max_ops = set((Op.Minimum, Op.Maximum,))
    binary_elem_wise_shift_ops = set((Op.SHL, Op.SHR,))
    binary_elem_wise_add_mul_sub = set((Op.Add, Op.Mul, Op.Sub,))
    binary_elem_wise_main_ops = binary_elem_wise_min_max_ops | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
    elem_wise_main_ops = binary_elem_wise_main_ops | unary_elem_wise_main_ops
    supported_int32_tensor_ops = (
        set((Op.ReduceSum, Op.CLZ,)) | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
    )
    activation_ops = set((Op.Relu, Op.Relu6, Op.ReluN1To1, Op.Sigmoid, Op.Tanh, Op.Softmax,))
    npu_post_ops = (
        # activation functions
        activation_ops
        # concatenation write direction
        | set((Op.ConcatSliceWrite,))
        # Quantization
        | set((Op.Quantize,))
    )
    split_ops = set((Op.Split, Op.SplitV, Op.StridedSlice, Op.Slice, Op.UnpackReshaped, Op.Unpack,))
    concat_ops = set((Op.Concat, Op.ConcatTFLite, Op.PackReshaped, Op.Pack,))
    memory_only_ops = set((Op.Squeeze, Op.Reshape, Op.QuantizedReshape, Op.ExpandDims,)) | concat_ops | split_ops
    shapeless_input_ops = binary_elem_wise_main_ops | set((Op.Split, Op.SplitV,))
    supported_fused_activations = set((Op.Relu, Op.Relu6, Op.ReluN1To1, Op.Tanh, Op.Sigmoid, Op.LUT,))
    supported_operators = npu_pre_ops | mac_main_ops | elem_wise_main_ops | npu_post_ops | memory_only_ops
    supported_dtypes = set((DataType.uint8, DataType.int8, DataType.int16, DataType.int32))
    # Defined ranges for allowed values:
    tens_dim_range = (1, 65535)
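    # Note: supported_operators above is the union consulted first by
    # is_operator_supported(); tens_dim_range is the inclusive range enforced by
    # constraint_tens_dimension().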

    def __init__(self):
        # Setup supported operator restriction checkers
        self.supported_operator_restrictions = {}
        self.supported_operator_restrictions.update(
            {op: self.check_convolution_restrictions for op in SupportedOperators.convolution_ops}
        )
        self.supported_operator_restrictions.update(
            {op: self.check_depthwise_convolution_restrictions for op in SupportedOperators.depthwise_convolution_ops}
        )
        self.supported_operator_restrictions.update(
            {op: self.check_transpose_convolution_restrictions for op in SupportedOperators.transpose_convolution_ops}
        )
        self.supported_operator_restrictions.update(
            {op: self.check_pooling_restrictions for op in SupportedOperators.pooling_ops}
        )
        self.supported_operator_restrictions.update(
            {op: self.check_resize_restrictions for op in SupportedOperators.resizing_ops}
        )
        self.supported_operator_restrictions.update(
            {op: self.check_vector_product_restrictions for op in SupportedOperators.fc_vector_products}
        )
        self.supported_operator_restrictions.update(
            {op: self.check_element_wise_restrictions for op in SupportedOperators.elem_wise_main_ops}
        )
        self.supported_operator_restrictions.update(
            {op: self.check_memory_only_restrictions for op in SupportedOperators.memory_only_ops}
        )
        self.supported_operator_restrictions.update(
            {op: self.check_activation_ops for op in SupportedOperators.activation_ops}
        )
        # Setup the generic constraints
        self.generic_constraints = []
        self.generic_constraints.append(SupportedOperators.constraint_tens_defined_shape)
        self.generic_constraints.append(SupportedOperators.constraint_tens_shapeless)
        self.generic_constraints.append(SupportedOperators.constraint_tens_shape_size)
        self.generic_constraints.append(SupportedOperators.constraint_tens_dtype)
        self.generic_constraints.append(SupportedOperators.constraint_tens_dimension)
        self.generic_constraints.append(SupportedOperators.constraint_faf)
        self.generic_constraints.append(SupportedOperators.constraint_tens_quant_scale)
        self.generic_constraints.append(SupportedOperators.constraint_tens_quant_none_check)

    def is_operator_supported(self, op):
        if op.type not in SupportedOperators.supported_operators:
            return False
        for constraint in self.generic_constraints:
            valid, extra = constraint(op)
            if not valid:
                print('Warning: "{}" is not supported on the NPU. Placing on CPU instead'.format(op.type))
                print(" - {}".format(constraint.__doc__))
                if extra:
                    print(" {}".format(extra))
                return False
        if op.type in self.supported_operator_restrictions:
            return self.supported_operator_restrictions[op.type](op)
        return True
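
    # Usage sketch (illustrative only; the calling code is assumed and not part
    # of this file):
    #
    #     support = SupportedOperators()
    #     if not support.is_operator_supported(op):
    #         ...  # leave the operation on the CPU
    #
    # Every generic constraint below returns a (valid, extra) tuple, and its
    # docstring is the human-readable rule printed when the check fails.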

    @staticmethod
    def constraint_tens_defined_shape(op):
        "Input(s) and Output Tensors must have a defined shape"
        valid = True
        extra = []
        for tens in op.inputs + op.outputs:
            if tens:
                valid &= tens.has_fully_defined_shape()
                extra.append("shape={}".format(tens.shape))
        return valid, " ".join(extra)

    @classmethod
    @docstring_format_args([shapeless_input_ops])
    def constraint_tens_shapeless(cls, op):
        "Scalar or Broadcasting Tensors are only valid for Input Tensors, and when op type is: {}"
        valid = True
        extra = []
        for tens in op.inputs:
            if tens and tens.shape == []:
                valid &= op.type in cls.shapeless_input_ops
                extra.append("shape={}".format(tens.shape))
        for tens in op.outputs:
            if tens.shape == []:
                valid = False
                extra.append("shape={}".format(tens.shape))
        return valid, " ".join(extra)

    @staticmethod
    def constraint_tens_shape_size(op):
        "Input(s) and Output Tensors must not be greater than 4D"
        valid = True
        extra = []
        for tens in op.inputs + op.outputs:
            if tens:
                valid &= len(tens.shape) <= 4
                extra.append("shape={}".format(tens.shape))
        return valid, " ".join(extra)

    @classmethod
    @docstring_format_args([supported_dtypes, supported_int32_tensor_ops])
    def constraint_tens_dtype(cls, op):
        "Tensors must be of type: {}. Tensors which are int32 are only valid when op type is: {}"
        valid = True
        extra = []
        tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
        tensors = tensors if tensors else op.inputs
        for tens in tensors:
            if tens.dtype == DataType.int32:
                valid &= op.type in cls.supported_int32_tensor_ops
            else:
                valid &= tens.dtype in cls.supported_dtypes
            extra.append("dtype={}".format(tens.dtype))
        return valid, " ".join(extra)

    @classmethod
    @docstring_format_args(tens_dim_range)
    def constraint_tens_dimension(cls, op):
        "Tensor dimensions must be in the range {}-{} (inclusive)"
        tens_min, tens_max = cls.tens_dim_range
        valid = True
        extra = []
        tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
        tensors = tensors if tensors else op.inputs
        for tens in tensors:
            valid &= all(tens_min <= dim <= tens_max for dim in tens.shape)
            extra.append("shape={}".format(tens.shape))
        return valid, " ".join(extra)

    @classmethod
    @docstring_format_args([supported_fused_activations])
    def constraint_faf(cls, op):
        "The fused activation function (if present) must be one of type: {}"
        faf = op.activation
        valid = (faf is None) or (faf in cls.supported_fused_activations)
        extra = "fused_activation_function={}".format(faf)
        return valid, extra

    @staticmethod
    def constraint_tens_quant_scale(op):
        "Tensors with quantization scales must be finite"
        valid = True
        extra = []
        tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
        for tens in tensors:
            if tens.quantization is not None and tens.quantization.scale_f32 is not None:
                valid &= not np.isinf(tens.quantization.scale_f32).any()
                extra.append("quantization.scale_f32={}".format(tens.quantization.scale_f32))
        return valid, " ".join(extra)

    @staticmethod
    def constraint_tens_quant_none_check(op):
        "Tensors must have quantization parameters"
        valid = True
        extra = []
        tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
        for tens in tensors:
            if tens.quantization is None:
                valid = False
                extra.append("Tensor '{}' has no quantization parameters".format(tens.name))
        return valid, ", ".join(extra)

    @classmethod
    def check_convolution_restrictions(cls, op):
        # check stride
        stride_w, stride_h = op.attrs["stride_w"], op.attrs["stride_h"]
        if not is_integer(stride_w) or not is_integer(stride_h):
            print("Warning:", op.type, "has non-integer stride, placing on CPU")
            return False
        if not 1 <= stride_w <= 3 or not 1 <= stride_h <= 3:
            print(
                "Warning: {} has stride ({}, {}), only strides in range [1, 3] are allowed. Placing on CPU".format(
                    op.type, stride_w, stride_h
                )
            )
            return False

        # check dilation
        dilation_w_factor = op.attrs.get("dilation_w_factor", 1)
        dilation_h_factor = op.attrs.get("dilation_h_factor", 1)
        if not is_integer(dilation_w_factor) or not is_integer(dilation_h_factor):
            print("Warning:", op.type, "has non-integer dilation factor, placing on CPU")
            return False
        if not 1 <= dilation_w_factor <= 2 or not 1 <= dilation_h_factor <= 2:
            print(
                "Warning:",
                op.type,
                "has dilation factors ({}, {}), only factors in range [1, 2] are allowed. Placing on CPU".format(
                    dilation_w_factor, dilation_h_factor
                ),
            )
            return False

        # check data type
        ifm_tensor, _, weight_tensor, bias_tensor, _ = op.get_ifm_ifm2_weights_biases_ofm()
        if weight_tensor.element_size() > 1:
            print("Warning: only 8-bit weights are supported, placing on CPU")
            return False

        if not cls.check_bias_restrictions(bias_tensor):
            return False

        # check kernel size [HWIO]
        dilated_weight_w = (weight_tensor.shape[1] - 1) * dilation_w_factor + 1
        dilated_weight_h = (weight_tensor.shape[0] - 1) * dilation_h_factor + 1

        # kernel limits
        if not 1 <= dilated_weight_h <= 64:
            print("Warning:", op.type, "has kernel height outside of range [1, 64], placing on CPU")
            return False
        if not 1 <= dilated_weight_w * dilated_weight_h <= 64 * 64:
            print(
                "Warning: product of kernel width and height must be >= 1 and not exceed 64 * 64 ({}),".format(64 * 64),
                "placing on CPU",
            )
            return False
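
        # Worked example (illustrative only): a 3x3 kernel with dilation factor 2
        # in both axes has a dilated extent of (3 - 1) * 2 + 1 = 5 per axis, so it
        # passes both the height limit (5 <= 64) and the area limit (25 <= 4096).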

        # check non const weights
        if weight_tensor.values is None:
            print("Warning:", op.type, "has non-constant weights, placing on CPU")
            return False

        # check weight sums over [HWI]
        zero_point = weight_tensor.quantization.zero_point
        quant_weights = weight_tensor.quant_values.astype(np.int64)
        weights = quant_weights - zero_point
        totals = np.sum(np.absolute(weights), axis=(0, 1, 2))

        if np.amax(totals) > 127 * 65536:
            print("Warning: sum of weights exceeds 127 * 65536 ({}), placing on CPU".format(127 * 65536))
            return False

        # check batch size
        if ifm_tensor.shape[0] != 1:
            print("Warning: only batch sizes of 1 are supported, placing on CPU")
            return False

        return True

    @classmethod
    def check_depthwise_convolution_restrictions(cls, op):
        # check depth
        ifm_tensor, ofm_tensor = op.get_ifm_ofm()
        if op.attrs["depth_multiplier"] > 1 and not (
            (ifm_tensor.shape[3] == 1) and (ofm_tensor.shape[3] == op.attrs["depth_multiplier"])
        ):
            print(
                "Warning: for depth multipliers > 1,",
                "number of input channels must be 1 and number of output channels must be equal to depth multiplier.",
                "Placing on CPU",
            )
            return False
        return cls.check_convolution_restrictions(op)

    @classmethod
    def check_transpose_convolution_restrictions(cls, op):
        # check stride
        stride_h, stride_w = op.attrs["stride_h"], op.attrs["stride_w"]
        if stride_h != 2 or stride_w != 2:
            print("Warning: stride must be equal to 2, placing on CPU")
            return False

        # check output dimensions
        ifm_tensor, weight_tensor, _, ofm_tensor = op.get_ifm_weights_biases_ofm()
        ifm_h, ifm_w = ifm_tensor.shape[1], ifm_tensor.shape[2]
        ofm_h, ofm_w = ofm_tensor.shape[1], ofm_tensor.shape[2]
        if op.attrs["padding"] == b"SAME":
            if (ofm_h != ifm_h * stride_h) or (ofm_w != ifm_w * stride_w):
                print(
                    "Warning: for",
                    op.type,
                    "using SAME padding, output dimensions must equal input dimensions multiplied by stride.",
                    "Placing on CPU",
                )
                return False
        elif op.attrs["padding"] == b"VALID":
            kernel_h, kernel_w = weight_tensor.shape[0], weight_tensor.shape[1]
            if (ofm_h != (ifm_h) * stride_h + max(kernel_h - stride_h, 0)) or (
                ofm_w != (ifm_w) * stride_w + max(kernel_w - stride_w, 0)
            ):
                print(
                    "Warning: for",
                    op.type,
                    "using VALID padding, output dimensions must equal input dimensions multiplied by stride,",
                    "minus difference between kernel size and stride. Placing on CPU",
                )
                return False
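
            # e.g. (illustrative only): with stride 2, a 3x3 kernel and VALID
            # padding, an 8x8 IFM must produce a 17x17 OFM, since 8 * 2 + max(3 - 2, 0) = 17.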

        return cls.check_convolution_restrictions(op)

    @classmethod
    def check_pooling_restrictions(cls, op):
        # check stride
        stride_w, stride_h = op.attrs["stride_w"], op.attrs["stride_h"]
        if not is_integer(stride_w) or not is_integer(stride_h):
            print("Warning:", op.type, "has non-integer stride, placing on CPU")
            return False
        if not 1 <= stride_w <= 3 or not 1 <= stride_h <= 3:
            print(
                "Warning: {} has stride ({}, {}), only strides in range [1, 3] are allowed. Placing on CPU".format(
                    op.type, stride_w, stride_h
                )
            )
            return False

        # check data type
        ifm_tensor, ofm_tensor = op.get_ifm_ofm()
        if ifm_tensor.dtype != ofm_tensor.dtype:
            if op.type != Op.ReduceSum:
                print("Warning: input data type doesn't match output data type, placing on CPU")
                return False
            # TODO: else check ReduceSum restrictions.

        # check batch size
        if ifm_tensor.shape[0] != 1:
            print("Warning: input batch size must be 1, placing on CPU")
            return False

        # check kernel size
        kernel_w, kernel_h = op.attrs["filter_width"], op.attrs["filter_height"]
        if op.type in cls.avg_pooling_ops and op.attrs["padding"] == b"SAME":
            if not 1 <= kernel_w <= 8 or not 1 <= kernel_h <= 8:
                print(
                    "Warning:",
                    op.type,
                    "has kernel size ({}, {}), only kernel sizes in range [1, 8] are allowed. Placing on CPU".format(
                        kernel_w, kernel_h
                    ),
                )
                return False
        if op.type in cls.avg_pooling_ops and op.attrs["padding"] == b"VALID" or op.type in cls.max_pooling_ops:
            if not 1 <= kernel_w * kernel_h <= 256 * 256:
                print(
                    "Warning: product of kernel width and height must be >= 1 and not exceed 256 * 256 ({}),".format(
                        256 * 256
                    ),
                    "placing on CPU",
                )
                return False
            if not 1 <= kernel_h <= 256:
                print("Warning:", op.type, "has kernel height outside of range [1, 256], placing on CPU")
                return False
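
        # e.g. (illustrative only): a 2x2 MaxPool passes these limits, while an
        # AveragePool with SAME padding and a 9x9 filter is rejected above.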

        return True

    @classmethod
    def check_resize_restrictions(cls, op):
        # check unsupported upscaling factor
        if op.type == Op.ResizeBilinear:
            if op.inputs[0].shape[1] == 1 and op.inputs[0].shape[2] == 1:
                return True
            if op.inputs[0].shape == op.outputs[0].shape:
                return True
            upscaled_shape = np.array(op.inputs[0].shape[1:3])
            out_shape = np.array(op.outputs[0].shape[1:3])
            while (upscaled_shape < out_shape).all():
                upscaled_shape *= 2
                if op.attrs["align_corners"]:
                    upscaled_shape -= 1
                if np.array_equal(out_shape, upscaled_shape):
                    return True
            return False
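
        # e.g. (illustrative only): 4x4 -> 8x8 and 4x4 -> 16x16 are accepted as
        # repeated x2 upscaling; with align_corners=True, 4x4 -> 7x7 is accepted
        # because each doubling step is reduced by one.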

    @classmethod
    def check_vector_product_restrictions(cls, op):
        # check data type
        ifm_tensor, _, weight_tensor, bias_tensor, _ = op.get_ifm_ifm2_weights_biases_ofm()
        if weight_tensor.element_size() > 1:
            print("Warning: only 8-bit datatypes supported for {}, placing on CPU".format(op.type))
            return False

        # check batch size
        batch_sizes = {1, 2, 4, 8}
        if ifm_tensor.shape[0] not in batch_sizes:
            print("Warning: only batch sizes {} supported for {}, placing on CPU".format(batch_sizes, op.type))
            return False

        if not cls.check_bias_restrictions(bias_tensor):
            return False

        # check non const weights
        if weight_tensor.values is None:
            print("Warning:", op.type, "has non-const weights, placing on CPU")
            return False

        return True

    @classmethod
    def check_element_wise_restrictions(cls, op):
        # check data type
        ifm_tensor, ifm2_tensor, _, ofm_tensor = op.get_ifm_ifm2_weights_ofm()
        # input and output datatype must match for these operators
        if (
            op.type in cls.binary_elem_wise_min_max_ops | cls.unary_elem_wise_main_ops
            and ifm_tensor.dtype != ofm_tensor.dtype
        ):
            print("Warning:", op.type, "must have same input and output datatype, placing on CPU")
            return False
        if op.type in cls.binary_elem_wise_add_mul_sub:
            # both inputs must have same type
            if ifm_tensor.dtype != ifm2_tensor.dtype:
                print("Warning:", op.type, "must have same datatype on both inputs, placing on CPU")
                return False
            # signed input check
            if ifm_tensor.dtype.type & BaseType.Signed:
                # output must be signed
                if ofm_tensor.dtype.type & BaseType.Unsigned:
                    print("Warning: only signed output types supported for {}, placing on CPU".format(op.type))
                    return False
                # and 8, 16 or 32-bit
                bit_lengths = {8, 16, 32}
                if ofm_tensor.element_size() * 8 not in bit_lengths:
                    print(
                        "Warning:", op.type, "is only supported for bit lengths {}, placing on CPU".format(bit_lengths)
                    )
                    return False
            # unsigned input check, output must be same type or int32
            if ifm_tensor.dtype.type & BaseType.Unsigned and not (
                ifm_tensor.dtype == ofm_tensor.dtype or ofm_tensor.dtype == DataType.int32
            ):
                print("Warning:", op.type, "has unsigned input but output is not unsigned or int32, placing on CPU")
                return False
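            # e.g. (illustrative only): uint8 inputs with an int32 output pass the
            # checks above, while int8 inputs with a uint8 output are rejected.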
        elif op.type in cls.binary_elem_wise_shift_ops:
            if ifm_tensor.dtype != DataType.int32 or ifm2_tensor.dtype != DataType.int32:
                print("Warning:", op.type, "input datatypes are not int32, placing on CPU")
                return False
            if op.type in (Op.CLZ, Op.SHL) and ofm_tensor.dtype != DataType.int32:
                print("Warning:", op.type, "output datatype is not int32, placing on CPU")
                return False

        # check batch size
        if len(ifm_tensor.shape) > 2 and ifm_tensor.shape[0] != 1:
            print(
                "Warning:",
                op.type,
                "only supports batch size 1 for tensors with more than 2 dimensions, placing on CPU",
            )
            return False
        if op.type in cls.binary_elem_wise_main_ops:  # if op type is unary, ifm2_tensor is None
            if len(ifm2_tensor.shape) > 2 and ifm2_tensor.shape[0] != 1:
                print(
                    "Warning:",
                    op.type,
                    "only supports batch size 1 for tensors with more than 2 dimensions, placing on CPU",
                )
                return False

        # negative alpha values are not supported
        if op.type == Op.LeakyRelu and op.attrs["alpha"] < 0:
            print("Warning:", op.type, "has negative alpha, placing on CPU")
            return False

        # check if ifm or ifm2 has ofm shape
        if ifm_tensor.shape != ofm_tensor.shape and ifm2_tensor.shape != ofm_tensor.shape:
            print("Warning:", op.type, "input shape(s) differ from output shape, placing on CPU")
            return False

        if op.type in cls.binary_elem_wise_min_max_ops and not cls.check_quantization_restrictions_binary_elem_wise(op):
            return False

        return True

    @classmethod
    def check_memory_only_restrictions(cls, op):
        if op.type == Op.StridedSlice:
            if len(op.inputs) != 4:
                warn_cpu(op, "has {} input tensors, only 4 inputs are supported".format(len(op.inputs)))
                return False
            input_tens, begin_tens, end_tens, strides_tens = op.inputs
            if begin_tens.values is None or end_tens.values is None or strides_tens.values is None:
                warn_cpu(op, "has a non-constant begin, end, or stride input tensor, which is not supported")
                return False
            if not (
                len(input_tens.shape)
                == len(op.outputs[0].shape)
                == len(begin_tens.values)
                == len(end_tens.values)
                == len(strides_tens.values)
            ):
                warn_cpu(op, "has input tensors with shapes that are not supported")
                return False
            # check stride size
            if any(stride != 1 for stride in strides_tens.values):
                warn_cpu(op, "has stride values {}, only stride 1 values are supported".format(strides_tens.values))
                return False
            # check ellipsis_mask
            if op.attrs["ellipsis_mask"] != 0:
                warn_cpu(op, "ellipsis_mask is {}, only 0 is supported".format(op.attrs["ellipsis_mask"]))
                return False
            # check if both new_axis_mask and shrink_axis_mask have bit set
            if op.attrs["new_axis_mask"] != 0 and op.attrs["shrink_axis_mask"] != 0:
                warn_cpu(op, "new_axis_mask and shrink_axis_mask are both non-zero, which is not supported")
                return False
            # Calculate offset start/end
            offset_start = get_slice_offsets(input_tens.shape, begin_tens, op.attrs["begin_mask"], is_begin=True)
            offset_end = get_slice_offsets(input_tens.shape, end_tens, op.attrs["end_mask"], is_begin=False)
            # check "end - begin" doesn't result in any zero or negative elements
            if any((end - begin) <= 0 for begin, end in zip(offset_start, offset_end)):
                warn_cpu(
                    op,
                    "has slice begin values {}, some of which are >= end values {}, which is illegal".format(
                        begin_tens.values, end_tens.values
                    ),
                )
                return False
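            # e.g. (illustrative only, all masks assumed zero): slicing a
            # [1, 4, 4, 8] input with begin [0, 0, 0, 0], end [1, 2, 2, 8] and
            # strides [1, 1, 1, 1] satisfies every StridedSlice check above.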
        if op.type == Op.SplitV:
            # check that at most one size is set to -1, indicating that the size should be inferred
            sizes = op.inputs[1].values
            num_to_be_inferred = 0
            for size in sizes:
                if size == -1:
                    num_to_be_inferred += 1

            if num_to_be_inferred > 1:
                print("Warning:", op.type, "has more than one size to be inferred, which is illegal, placing on CPU")
                return False
        if op.type in set((Op.Concat, Op.ConcatTFLite,)):
            axis = op.attrs.get("axis", None)
            if axis is None:
                print("Warning:", op.type, "invalid or missing axis, placing on CPU")
                return False
            if axis < 0:
                axis += len(op.inputs[0].shape)
            if not 0 <= axis < len(op.inputs[0].shape):
                print("Warning:", op.type, "invalid axis", axis, ", placing on CPU")
                return False
            ofm = op.outputs[0]
            ofm_dims = len(ofm.shape)
            for ifm in op.inputs:
                if len(ifm.shape) != ofm_dims:
                    return False
                for i in range(ofm_dims):
                    if i != axis and ifm.shape[i] != ofm.shape[i]:
                        print(
                            "Warning:",
                            op.type,
                            "invalid ifm:",
                            ifm.name,
                            ifm.shape,
                            "mismatch in dimension",
                            i,
                            ", placing on CPU",
                        )
                        return False
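
            # e.g. (illustrative only): concatenating two [1, 8, 8, 16] tensors
            # along axis 3 into a [1, 8, 8, 32] output satisfies the checks above.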

        return True

    @classmethod
    def check_quantization_restrictions_binary_elem_wise(cls, op):
        # makes sure IFM1, IFM2 and OFM quantization are equal for binary ops
        assert len(op.inputs) >= 2 and len(op.outputs) == 1

        if (
            op.inputs[0].quantization is None
            or not op.inputs[0].is_scaling_equal(op.inputs[1])
            or not op.inputs[0].is_scaling_equal(op.outputs[0])
        ):
            print(
                "Warning: Input/output tensors with different quantization is unsupported for the", op.type, "operator"
            )
            return False

        return True

    @classmethod
    def check_activation_ops(cls, op):
        if op.type == Op.Softmax:
            ifm_tensor = op.inputs[0]
            ofm_tensor = op.outputs[0]

            # check data type
            if ifm_tensor.dtype != ofm_tensor.dtype:
                print("Warning:", op.type, "input type differs from output type, placing on CPU")
                return False

            if ifm_tensor.dtype not in (DataType.uint8, DataType.int8, DataType.int16):
                print(
                    "Warning: only datatypes supported for {} are uint8, int8 and int16; placing on CPU".format(op.type)
                )
                return False

            # check shape
            if ifm_tensor.shape != ofm_tensor.shape:
                print("Warning:", op.type, "input shape differs from output shape, placing on CPU")
                return False

        return True

    @classmethod
    def check_bias_restrictions(cls, bias_tensor):
        # check data type
        if bias_tensor is not None and bias_tensor.dtype not in (DataType.int32, DataType.int64):
            print("Warning: bias tensor datatype must be int32 or int64, placing on CPU")
            return False

        # check if values fit in 40 bits
        if bias_tensor is not None and bias_tensor.dtype == DataType.int64:
            for quant_value in bias_tensor.quant_values:
                if not (-(1 << 39) <= quant_value < (1 << 39)):
                    print("Warning: bias tensor values are larger than 40 bits, placing on CPU")
                    return False

        return True