# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# The SupportedOperators class which is a collection of all supported operators and parameter checks.
import numpy as np

from .data_type import BaseType
from .data_type import DataType
from .operation import get_slice_offsets


# Custom decorator function to allow formatting docstrings containing "{}"
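# (e.g. decorating a constraint with @docstring_format_args([supported_dtypes]) substitutes the
# concrete set of allowed data types into its docstring, which is what the failure warning prints)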
def docstring_format_args(args):
    def docstring(func):
        func.__doc__ = func.__doc__.format(*args)
        return func

    return docstring


def warn_cpu(op, msg):
    print("Warning: {} {}, placing on CPU".format(op.type, msg))


class SupportedOperators:
    # Categorised lists of supported operators
    npu_pre_ops = set(("QuantizedResizeBilinear", "SplitSliceRead",))
    convolution_ops = set(("Conv2DBiasAct", "Conv2D", "QuantizedConv2D",))
    depthwise_convolution_ops = set(("DepthwiseConv2dBiasAct", "DepthwiseConv2dNative", "QuantizedDepthwiseConv2D",))
    transpose_convolution_ops = set(("Conv2DBackpropInput",))
    max_pooling_ops = set(("QuantizedMaxPool", "MaxPool", "MaxPoolAct",))
    avg_pooling_ops = set(("QuantizedAvgPool", "AvgPool", "AvgPoolAct",))
    pooling_ops = set(("ReduceSum",)) | max_pooling_ops | avg_pooling_ops
    resizing_ops = set(("ResizeBilinear",))
    fc_vector_products = set(("QuantizedMatMul", "MatMul", "FullyConnectedAct",))
    mac_main_ops = (
        # RNN/LSTM/GRU
        set(("BlockLSTM",))
        # convolutions
        | convolution_ops
        # depth-wise convolutions
        | depthwise_convolution_ops
        # transpose convolutions
        | transpose_convolution_ops
        # pooling
        | pooling_ops
        # resizing/upscaling
        | resizing_ops
        # FC layers
        | fc_vector_products
    )
    unary_elem_wise_main_ops = set(("LeakyRelu", "Abs", "CLZ",))
    binary_elem_wise_min_max_ops = set(("Minimum", "Maximum",))
    binary_elem_wise_shift_ops = set(("SHL", "SHR",))
    binary_elem_wise_add_mul_sub = set(
        ("AddAct", "MulAct", "SubAct", "QuantizedAdd", "QuantizedSub", "QuantizedMul", "Mul", "Add", "Sub",)
    )
    binary_elem_wise_main_ops = binary_elem_wise_min_max_ops | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
    elem_wise_main_ops = binary_elem_wise_main_ops | unary_elem_wise_main_ops
    supported_int32_tensor_ops = (
        set(("Requantize", "ReduceSum", "CLZ",)) | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
    )
    activation_ops = set(
        (
            "QuantizedRelu",
            "QuantizedRelu1",
            "QuantizedRelu6",
            "Relu",
            "Relu6",
            "ReluN1To1",
            "Sigmoid",
            "Tanh",
            "Softmax",
        )
    )
    npu_post_ops = (
        # concatenation write direction
        set(("ConcatSliceWrite",))
        # bias add and batch norm
        | set(("QuantizedBiasAdd", "Requantize", "QuantizedBatchNorm", "BiasAdd", "FusedBatchNorm",))
        # Quantization
        | set(("Quantize",))
        # activation functions
        | activation_ops
    )
    split_ops = set(("Split", "SplitV", "StridedSlice", "Slice", "UnpackReshaped", "Unpack",))
    concat_ops = set(("Concat", "ConcatV2", "QuantizedConcat", "ConcatTFLite", "PackReshaped", "Pack",))
    memory_only_ops = set(("Squeeze", "Reshape", "QuantizedReshape", "ExpandDims",)) | concat_ops | split_ops
    shapeless_input_ops = set(("Split", "SplitV",)) | binary_elem_wise_main_ops
    supported_fused_activations = set(("Relu", "Relu6", "ReluN1To1", "Tanh", "Sigmoid", "LUT",))
    supported_operators = npu_pre_ops | mac_main_ops | elem_wise_main_ops | npu_post_ops | memory_only_ops
    supported_dtypes = set((DataType.uint8, DataType.int8, DataType.int16, DataType.int32))
    # Defined ranges for allowed values:
    tens_dim_range = (1, 65535)

    def __init__(self):
        # Setup supported operator restriction checkers
        self.supported_operator_restrictions = {}
        self.supported_operator_restrictions.update(
            {op: self.check_convolution_restrictions for op in SupportedOperators.convolution_ops}
        )
        self.supported_operator_restrictions.update(
            {op: self.check_depthwise_convolution_restrictions for op in SupportedOperators.depthwise_convolution_ops}
        )
        self.supported_operator_restrictions.update(
            {op: self.check_transpose_convolution_restrictions for op in SupportedOperators.transpose_convolution_ops}
        )
        self.supported_operator_restrictions.update(
            {op: self.check_pooling_restrictions for op in SupportedOperators.pooling_ops}
        )
        self.supported_operator_restrictions.update(
            {op: self.check_resize_restrictions for op in SupportedOperators.resizing_ops}
        )
        self.supported_operator_restrictions.update(
            {op: self.check_vector_product_restrictions for op in SupportedOperators.fc_vector_products}
        )
        self.supported_operator_restrictions.update(
            {op: self.check_element_wise_restrictions for op in SupportedOperators.elem_wise_main_ops}
        )
        self.supported_operator_restrictions.update(
            {op: self.check_memory_only_restrictions for op in SupportedOperators.memory_only_ops}
        )
        self.supported_operator_restrictions.update(
            {op: self.check_activation_ops for op in SupportedOperators.activation_ops}
        )
        # Setup the generic constraints
        self.generic_constraints = []
        self.generic_constraints.append(SupportedOperators.constraint_tens_defined_shape)
        self.generic_constraints.append(SupportedOperators.constraint_tens_shapeless)
        self.generic_constraints.append(SupportedOperators.constraint_tens_shape_size)
        self.generic_constraints.append(SupportedOperators.constraint_tens_dtype)
        self.generic_constraints.append(SupportedOperators.constraint_tens_dimension)
        self.generic_constraints.append(SupportedOperators.constraint_faf)
        self.generic_constraints.append(SupportedOperators.constraint_tens_quant_scale)

    def is_operator_supported(self, op):
        if op.type not in SupportedOperators.supported_operators:
            return False
        for constraint in self.generic_constraints:
            valid, extra = constraint(op)
            if not valid:
                print('Warning: "{}" is not supported on the NPU. Placing on CPU instead'.format(op.type))
                print(" - {}".format(constraint.__doc__))
                if extra:
                    print("   {}".format(extra))
                return False
        if op.type in self.supported_operator_restrictions:
            return self.supported_operator_restrictions[op.type](op)
        return True
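
    # Typical use (sketch, for illustration only): construct one instance and query it per
    # operation, leaving unsupported operations on the CPU:
    #   supported = SupportedOperators()
    #   if not supported.is_operator_supported(op):
    #       ...  # keep op in the CPU subgraph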

    @staticmethod
    def constraint_tens_defined_shape(op):
        "Input(s) and Output Tensors must have a defined shape"
        valid = True
        extra = []
        for tens in op.inputs + op.outputs:
            if tens:
                valid &= tens.has_fully_defined_shape()
                extra.append("shape={}".format(tens.shape))
        return valid, " ".join(extra)

    @classmethod
    @docstring_format_args([shapeless_input_ops])
    def constraint_tens_shapeless(cls, op):
        "Scalar or Broadcasting Tensors are only valid for Input Tensors, and when op type is: {}"
        valid = True
        extra = []
        for tens in op.inputs:
            if tens and tens.shape == []:
                valid &= op.type in cls.shapeless_input_ops
                extra.append("shape={}".format(tens.shape))
        for tens in op.outputs:
            if tens.shape == []:
                valid = False
                extra.append("shape={}".format(tens.shape))
        return valid, " ".join(extra)

    @staticmethod
    def constraint_tens_shape_size(op):
        "Input(s) and Output Tensors must not be greater than 4D"
        valid = True
        extra = []
        for tens in op.inputs + op.outputs:
            if tens:
                valid &= len(tens.shape) <= 4
                extra.append("shape={}".format(tens.shape))
        return valid, " ".join(extra)

    @classmethod
    @docstring_format_args([supported_dtypes, supported_int32_tensor_ops])
    def constraint_tens_dtype(cls, op):
        "Tensors must be of type: {}. Tensors which are int32 are only valid when op type is: {}"
        valid = True
        extra = []
        tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
        tensors = tensors if tensors else op.inputs
        for tens in tensors:
            if tens.dtype == DataType.int32:
                valid &= op.type in cls.supported_int32_tensor_ops
            else:
                valid &= tens.dtype in cls.supported_dtypes
            extra.append("dtype={}".format(tens.dtype))
        return valid, " ".join(extra)

    @classmethod
    @docstring_format_args(tens_dim_range)
    def constraint_tens_dimension(cls, op):
        "Tensor dimensions must be in the range {}-{} (inclusive)"
        tens_min, tens_max = cls.tens_dim_range
        valid = True
        extra = []
        tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
        tensors = tensors if tensors else op.inputs
        for tens in tensors:
            valid &= all(tens_min <= dim <= tens_max for dim in tens.shape)
            extra.append("shape={}".format(tens.shape))
        return valid, " ".join(extra)

    @classmethod
    @docstring_format_args([supported_fused_activations])
    def constraint_faf(cls, op):
        "The fused activation function (if present) must be one of type: {}"
        faf = op.attrs.get("fused_activation_function")
        valid = (faf is None) or (faf in cls.supported_fused_activations)
        extra = "fused_activation_function={}".format(faf)
        return valid, extra

    @staticmethod
    def constraint_tens_quant_scale(op):
        "Tensors with quantization scales must be finite"
        valid = True
        extra = []
        tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
        for tens in tensors:
            if tens.quantization is not None and tens.quantization.scale_f32 is not None:
                valid &= not np.isinf(tens.quantization.scale_f32).any()
                extra.append("quantization.scale_f32={}".format(tens.quantization.scale_f32))
        return valid, " ".join(extra)

    @classmethod
    def check_convolution_restrictions(cls, op):
        # check stride
        if op.attrs["stride_w"] > 3 or op.attrs["stride_h"] > 3:
            return False

        # check dilation
        dilation_w_factor = op.attrs.get("dilation_w_factor", 1)
        dilation_h_factor = op.attrs.get("dilation_h_factor", 1)
        if dilation_w_factor > 2 or dilation_h_factor > 2:
            return False

        # check data type
        ifm_tensor, _, weight_tensor, bias_tensor, _ = op.get_ifm_ifm2_weights_biases_ofm()
        if weight_tensor.element_size() > 1:
            return False

        if not cls.check_bias_restrictions(bias_tensor):
            return False

        # check kernel size [HWIO]
        dilated_weight_w = weight_tensor.shape[1] + (weight_tensor.shape[1] - 1) * (dilation_w_factor - 1)
        dilated_weight_h = weight_tensor.shape[0] + (weight_tensor.shape[0] - 1) * (dilation_h_factor - 1)

        if dilated_weight_w > 64 or dilated_weight_h > 64:
            return False
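        # Worked example (illustrative): a 3x3 kernel with dilation factor 2 has a dilated
        # size of 3 + (3 - 1) * (2 - 1) = 5 in that dimension, well inside the 64 limit above.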

        # check non const weights
        if weight_tensor.values is None:
            print("Warning:", op.type, "has non-const weights, placing on CPU")
            return False

        # check weight sums over [HWI]
        zero_point = weight_tensor.quantization.zero_point
        quant_weights = weight_tensor.quant_values.astype(np.int64)
        weights = quant_weights - zero_point
        totals = np.sum(np.absolute(weights), axis=(0, 1, 2))

        if np.amax(totals) > 127 * 65536:
            return False
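        # i.e. for each output channel, the sum of |quantised weight - zero point| over the
        # H, W and input-channel axes must not exceed 127 * 65536 = 8,323,072.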

        # check batch size
        if ifm_tensor.shape[0] != 1:
            return False

        return True

    @classmethod
    def check_depthwise_convolution_restrictions(cls, op):
        # check depth
        ifm_tensor, _, _, ofm_tensor = op.get_ifm_ifm2_weights_ofm()
        if op.attrs["depth_multiplier"] > 1 and not (
            (ifm_tensor.shape[3] == 1) and (ofm_tensor.shape[3] == op.attrs["depth_multiplier"])
        ):
            return False
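        # e.g. depth_multiplier=2 is only accepted when the IFM has a single channel and the
        # OFM has exactly 2 channels. (Illustrative reading of the condition above.)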
        return cls.check_convolution_restrictions(op)

    @classmethod
    def check_transpose_convolution_restrictions(cls, op):
        # check stride
        stride_h, stride_w = op.attrs["stride_h"], op.attrs["stride_w"]
        if stride_h != 2 or stride_w != 2:
            return False

        # check output dimensions
        ifm_tensor, weight_tensor, _, ofm_tensor = op.get_ifm_weights_biases_ofm()
        ifm_h, ifm_w = ifm_tensor.shape[1], ifm_tensor.shape[2]
        ofm_h, ofm_w = ofm_tensor.shape[1], ofm_tensor.shape[2]
        if op.attrs["padding"] == b"SAME":
            if (ofm_h != ifm_h * stride_h) or (ofm_w != ifm_w * stride_w):
                return False
        elif op.attrs["padding"] == b"VALID":
            kernel_h, kernel_w = weight_tensor.shape[0], weight_tensor.shape[1]
            if (ofm_h != (ifm_h) * stride_h + max(kernel_h - stride_h, 0)) or (
                ofm_w != (ifm_w) * stride_w + max(kernel_w - stride_w, 0)
            ):
                return False
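        # Worked example (illustrative): with stride 2, a 16x16 IFM must give a 32x32 OFM under
        # SAME padding; under VALID padding with a 3x3 kernel it must give
        # 16 * 2 + max(3 - 2, 0) = 33 per spatial dimension.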

        return cls.check_convolution_restrictions(op)

    @classmethod
    def check_pooling_restrictions(cls, op):
        # check stride
        if op.attrs["stride_w"] > 3 or op.attrs["stride_h"] > 3:
            return False

        # check data type
        ifm_tensor, _, _, ofm_tensor = op.get_ifm_ifm2_weights_ofm()
        if ifm_tensor.dtype != ofm_tensor.dtype:
            if op.type != "ReduceSum":
                return False
            # TODO: else check ReduceSum restrictions.

        # check batch size
        if ifm_tensor.shape[0] != 1:
            return False

        if op.type in cls.avg_pooling_ops:
            # check kernel size
            if op.attrs["padding"] == b"SAME" and (op.attrs["filter_width"] > 8 or op.attrs["filter_height"] > 8):
                return False
            if op.attrs["padding"] == b"VALID" and (
                op.attrs["filter_width"] * op.attrs["filter_height"] > 256 * 256 or op.attrs["filter_height"] > 256
            ):
                return False

        if op.type in cls.max_pooling_ops:
            # check kernel size (any padding)
            if op.attrs["filter_width"] * op.attrs["filter_height"] > 256 * 256 or op.attrs["filter_height"] > 256:
                return False
        return True

    @classmethod
    def check_resize_restrictions(cls, op):
        # check unsupported upscaling factor
        if op.type == "ResizeBilinear":
            if op.inputs[0].shape[1] == 1 and op.inputs[0].shape[2] == 1:
                return True
            if op.inputs[0].shape == op.outputs[0].shape:
                return True
            upscaled_shape = np.array(op.inputs[0].shape[1:3])
            out_shape = np.array(op.outputs[0].shape[1:3])
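            # Illustrative: a 4x4 input can reach 8x8, 16x16, ... by repeated 2x upscaling;
            # with align_corners=True the reachable output sizes are 7x7, 13x13, ... instead.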
            while (upscaled_shape < out_shape).all():
                upscaled_shape *= 2
                if op.attrs["align_corners"]:
                    upscaled_shape -= 1
                if np.array_equal(out_shape, upscaled_shape):
                    return True
        return False

    @classmethod
    def check_vector_product_restrictions(cls, op):
        # check data type
        _, _, weight_tensor, bias_tensor, _ = op.get_ifm_ifm2_weights_biases_ofm()
        if weight_tensor.element_size() > 1:
            return False

        if not cls.check_bias_restrictions(bias_tensor):
            return False

        # check non const weights
        if weight_tensor.values is None:
            print("Warning:", op.type, "has non-const weights, placing on CPU")
            return False

        return True

    @classmethod
    def check_element_wise_restrictions(cls, op):
        # check data type
        ifm_tensor, ifm2_tensor, _, ofm_tensor = op.get_ifm_ifm2_weights_ofm()
        # input and output datatype must match for these operators
        if (
            op.type in cls.binary_elem_wise_min_max_ops | cls.unary_elem_wise_main_ops
            and ifm_tensor.dtype != ofm_tensor.dtype
        ):
            return False
        if op.type in cls.binary_elem_wise_add_mul_sub:
            # both inputs must have same type
            if ifm_tensor.dtype != ifm2_tensor.dtype:
                return False
            # signed input check
            if ifm_tensor.dtype.type & BaseType.Signed:
                # output must be signed
                if ofm_tensor.dtype.type & BaseType.Unsigned:
                    return False
                # and 8, 16 or 32-bit
                if ofm_tensor.element_size() not in (1, 2, 4):
                    return False
            # unsigned input check, output must be same type or int32
            if ifm_tensor.dtype.type & BaseType.Unsigned and not (
                ifm_tensor.dtype == ofm_tensor.dtype or ofm_tensor.dtype == DataType.int32
            ):
                return False
        elif op.type in cls.binary_elem_wise_shift_ops | set(("CLZ",)):
            if ifm_tensor.dtype != DataType.int32 or ifm2_tensor.dtype != DataType.int32:
                return False
            if op.type in ("CLZ", "SHL") and ofm_tensor.dtype != DataType.int32:
                return False

        # check batch size
        if len(ifm_tensor.shape) > 2 and ifm_tensor.shape[0] != 1:
            return False
        if op.type in cls.binary_elem_wise_main_ops:  # if op type is unary, ifm2_tensor is None
            if len(ifm2_tensor.shape) > 2 and ifm2_tensor.shape[0] != 1:
                return False

        # negative alpha values are not supported
        if op.type == "LeakyRelu" and op.attrs["alpha"] < 0:
            return False

        # check if ifm or ifm2 has ofm shape
        if ifm_tensor.shape != ofm_tensor.shape and ifm2_tensor.shape != ofm_tensor.shape:
            return False

        if op.type in cls.binary_elem_wise_min_max_ops and not cls.check_quantization_restrictions_binary_elem_wise(op):
            return False

        return True

    @classmethod
    def check_memory_only_restrictions(cls, op):
        if op.type == "StridedSlice":
            if len(op.inputs) != 4:
                warn_cpu(op, "has {} input tensors, only 4 inputs are supported".format(len(op.inputs)))
                return False
            input_tens, begin_tens, end_tens, strides_tens = op.inputs
            if begin_tens.values is None or end_tens.values is None or strides_tens.values is None:
                warn_cpu(op, "has a non-constant begin, end, or stride input tensor, which is not supported")
                return False
            if not (
                len(input_tens.shape)
                == len(op.outputs[0].shape)
                == len(begin_tens.values)
                == len(end_tens.values)
                == len(strides_tens.values)
            ):
                warn_cpu(op, "has input tensors with shapes that are not supported")
                return False
            # check stride size
            if any(stride != 1 for stride in strides_tens.values):
                warn_cpu(op, "has stride values {}, only stride 1 values are supported".format(strides_tens.values))
                return False
            # check ellipsis_mask
            if op.attrs["ellipsis_mask"] != 0:
                warn_cpu(op, "ellipsis_mask is {}, only 0 is supported".format(op.attrs["ellipsis_mask"]))
                return False
            # check if both new_axis_mask and shrink_axis_mask have bit set
            if op.attrs["new_axis_mask"] != 0 and op.attrs["shrink_axis_mask"] != 0:
                warn_cpu(op, "new_axis_mask and shrink_axis_mask are both non-zero, which is not supported")
                return False
            # Calculate offset start/end
            offset_start = get_slice_offsets(input_tens.shape, begin_tens, op.attrs["begin_mask"], is_begin=True)
            offset_end = get_slice_offsets(input_tens.shape, end_tens, op.attrs["end_mask"], is_begin=False)
            # check "end - begin" doesn't result in any zero or negative elements
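            # (Illustrative, ignoring the mask attributes: begin=[0, 0, 0, 0] with end=[1, 8, 8, 16]
            # on a [1, 8, 8, 16] input passes, whereas any axis where begin >= end is rejected.)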
            if any((end - begin) <= 0 for begin, end in zip(offset_start, offset_end)):
                warn_cpu(
                    op,
                    "has slice begin values {}, some of which are >= end values {}, which is illegal".format(
                        begin_tens.values, end_tens.values
                    ),
                )
                return False
        if op.type == "SplitV":
            # check that maximum one size is set to -1, indicating that size should be inferred
            sizes = op.inputs[1].values
            num_to_be_inferred = 0
            for size in sizes:
                if size == -1:
                    num_to_be_inferred += 1

            if num_to_be_inferred > 1:
                print("Warning:", op.type, "has more than one size to be inferred, which is illegal, placing on CPU")
                return False
        if op.type.find("Concat") != -1:
            axis = op.attrs.get("axis", None)
            if axis is None:
                print("Warning:", op.type, "invalid or missing axis, placing on CPU")
                return False
            if axis < 0:
                axis += len(op.inputs[0].shape)
            if not 0 < axis < len(op.inputs[0].shape):
                print("Warning:", op.type, "invalid axis", axis, ", placing on CPU")
                return False
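            # e.g. for a 4D input, axis=-1 normalises to 3 (the channel axis) and is accepted,
            # while axis=0 (the batch axis) fails the range check above. (Illustrative.)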
            ofm = op.outputs[0]
            ofm_dims = len(ofm.shape)
            for ifm in op.inputs:
                if len(ifm.shape) != ofm_dims:
                    return False
                for i in range(ofm_dims):
                    if i != axis and ifm.shape[i] != ofm.shape[i]:
                        print(
                            "Warning:",
                            op.type,
                            "invalid ifm:",
                            ifm.name,
                            ifm.shape,
                            "mismatch in dimension",
                            i,
                            ", placing on CPU",
                        )
                        return False

        return True

    @classmethod
    def check_quantization_restrictions_binary_elem_wise(cls, op):
        # makes sure IFM1, IFM2 and OFM quantization are equal for binary ops
        assert len(op.inputs) >= 2 and len(op.outputs) == 1

        if (
            op.inputs[0].quantization is None
            or not op.inputs[0].is_scaling_equal(op.inputs[1])
            or not op.inputs[0].is_scaling_equal(op.outputs[0])
        ):
            print(
                "Warning: Input/output tensors with different quantization are unsupported for the", op.type, "operator"
            )
            return False

        return True

    @classmethod
    def check_activation_ops(cls, op):
        if op.type == "Softmax":
            ifm_tensor = op.inputs[0]
            ofm_tensor = op.outputs[0]

            # check data type
            if ifm_tensor.dtype != ofm_tensor.dtype:
                return False

            if ifm_tensor.dtype not in (DataType.uint8, DataType.int8, DataType.int16):
                return False

            # check shape
            if ifm_tensor.shape != ofm_tensor.shape:
                return False

        return True

    @classmethod
    def check_bias_restrictions(cls, bias_tensor):
        # check data type
        if bias_tensor is not None and bias_tensor.dtype not in (DataType.int32, DataType.int64):
            return False

        # check if values fits in 40-bit
        if bias_tensor is not None and bias_tensor.dtype == DataType.int64:
            for quant_value in bias_tensor.quant_values:
                if not (-(1 << 39) <= quant_value < (1 << 39)):
                    return False
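        # i.e. int64 bias values must lie within the signed 40-bit range [-(2**39), 2**39 - 1],
        # that is [-549755813888, 549755813887].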

        return True