Patrik Gustavsson | 8f1f9aa | 2021-06-28 07:41:58 +0200 | [diff] [blame] | 1 | # Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved. |
| 2 | # |
| 3 | # SPDX-License-Identifier: Apache-2.0 |
| 4 | # |
| 5 | # Licensed under the Apache License, Version 2.0 (the License); you may |
| 6 | # not use this file except in compliance with the License. |
| 7 | # You may obtain a copy of the License at |
| 8 | # |
| 9 | # www.apache.org/licenses/LICENSE-2.0 |
| 10 | # |
| 11 | # Unless required by applicable law or agreed to in writing, software |
| 12 | # distributed under the License is distributed on an AS IS BASIS, WITHOUT |
| 13 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | # See the License for the specific language governing permissions and |
| 15 | # limitations under the License. |
| 16 | # Description: |
| 17 | # The TosaSupportedOperators class which is a collection of all supported operators and parameter checks. |
| 18 | from collections import defaultdict |
| 19 | |
| 20 | from .data_type import DataType |
| 21 | from .operation import Op |
| 22 | from .supported_operators_util import docstring_format_args |
| 23 | from .supported_operators_util import list_formatter |
| 24 | from .tosa_mapping import optype_to_tosa_op_type |
| 25 | |
| 26 | |
class TosaSupportedOperators:
    """Collection of the operator types accepted on the TOSA path, plus the checks applied to them.

    Operator types are grouped into category sets at class level; their union is
    ``supported_operators``. An instance holds two kinds of constraint functions:
    generic ones run for every supported op, and specific ones keyed by op type.
    Each constraint returns a ``(valid, extra)`` pair.
    """

    # TODO currently sparsely populated
    # Categorised sets of supported operator types. Some names are deliberate
    # aliases of the same set object (e.g. convolution_like_ops/convolution_ops).
    convolution_ops = {Op.Conv2DBias}
    convolution_like_ops = convolution_ops

    max_pooling_ops = Op.op_set(Op.is_maxpool_op)
    avg_pooling_ops = Op.op_set(Op.is_avgpool_op)
    pooling_ops = {Op.ReduceSum} | max_pooling_ops | avg_pooling_ops

    mac_main_ops = convolution_like_ops | pooling_ops

    type_conversion_ops = {Op.Rescale}
    relu_ops = {Op.Clamp, Op.ReluN}
    activation_ops = relu_ops

    npu_post_ops = activation_ops
    supported_operators = mac_main_ops | type_conversion_ops | npu_post_ops

    # Data types accepted for an op's tensors.
    # TODO will differ compared to TensorFlow Lite, currently set to the same
    supported_op_dtypes = {DataType.uint8, DataType.int8, DataType.int16, DataType.int32}

    def __init__(self):
        # Constraints checked for every supported op. Note: the order matters.
        self.generic_constraints = [TosaSupportedOperators.constraint_tens_dtype]

        # Op type -> list of extra constraints, checked after the generic ones.
        # Note: the order matters.
        self.specific_constraints = defaultdict(list)

    def is_operator_supported(self, op):
        """Return True when ``op`` passes the type check and every constraint.

        If ``op.type`` is not in ``supported_operators``, an info line is printed
        (except for Placeholder, SubgraphInput and Const ops) and False is returned.
        Otherwise, the generic constraints and then those registered for
        ``op.type`` are evaluated in order. The first failing constraint causes a
        warning, its docstring and any extra detail to be printed, and False to be
        returned.
        """
        ext_type = optype_to_tosa_op_type(op.type)

        if op.type not in TosaSupportedOperators.supported_operators:
            silent_types = (Op.Placeholder, Op.SubgraphInput, Op.Const)
            if op.type not in silent_types:
                print(f"Info: {ext_type} '{op.name}' is not a NPU op")
            return False

        checks = self.generic_constraints + self.specific_constraints[op.type]
        for check in checks:
            ok, detail = check(op)
            if ok:
                continue
            print(f"Warning: {ext_type} '{op.name}' is not supported on the NPU")
            print(f" - {check.__doc__}")
            if detail:
                print(f" {detail}")
            return False

        return True

    # Checks that each relevant tensor's dtype is in supported_op_dtypes.
    # Relevant tensors are those from op.get_ifm_ifm2_weights_ofm(); when none
    # are present, the op's non-empty inputs are used instead. Returns
    # (all_ok, comma-separated description of each offending tensor).
    # The bare docstring below is formatted by docstring_format_args and printed
    # on failure, so its text is part of the runtime output.
    # TODO this function is the same for TensorFlow Lite, but input might differ
    @classmethod
    @docstring_format_args([list_formatter(supported_op_dtypes)])
    def constraint_tens_dtype(cls, op):
        "Tensors must be of type: {}"
        candidates = [t for t in op.get_ifm_ifm2_weights_ofm() if t]
        if not candidates:
            candidates = [t for t in op.inputs if t]
        offending = [t for t in candidates if t.dtype not in cls.supported_op_dtypes]
        details = [f"Tensor '{t.name}' has data type: {t.dtype}" for t in offending]
        return not offending, ", ".join(details)