blob: 5cd186c67957ac4de6469332eaa18124d4b2404b [file] [log] [blame]
Jonas Ohlsson45e653d2021-07-26 16:13:12 +02001# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16# Description:
17# The TosaSemantic class which is a collection of TOSA model semantic checks.
18from collections import defaultdict
19
20from .operation import Op
21from .tosa_mapping import optype_to_tosa_op_type
22
23
class TosaSemantic:
    """Collection of TOSA model semantic checks applied to single operations.

    A constraint is a callable taking an operation and returning a
    ``(valid, extra)`` pair: ``valid`` is truthy when the operation passes,
    and ``extra`` is optional detail text printed when it does not. The
    constraint's docstring is printed as the human-readable reason.
    """

    # TODO populate this

    def __init__(self):
        # Checks run against every operation; list order is evaluation order.
        self.generic_constraints = list()

        # Checks keyed by ``op.type``; they run after the generic ones, and
        # list order within each entry is evaluation order.
        self.specific_constraints = defaultdict(list)

    def is_operator_semantic_valid(self, op):
        """Return whether ``op`` satisfies every applicable semantic constraint.

        Placeholder, SubgraphInput and Const operations are accepted without
        running any constraint. For all other operations the generic
        constraints run first, followed by those registered for ``op.type``.
        Evaluation stops at the first failing constraint, after printing a
        warning with the op's TOSA type and name, the failing constraint's
        docstring and, when present, its extra detail.

        Args:
            op: the operation to check; ``op.type`` and ``op.name`` are read.

        Returns:
            True if all constraints pass (or the op is exempt), else False.
        """
        tosa_op_type = optype_to_tosa_op_type(op.type)

        # Graph inputs and constants are accepted unconditionally.
        exempt_types = (Op.Placeholder, Op.SubgraphInput, Op.Const)
        if op.type in exempt_types:
            return True

        # Generic checks first, then the op-type-specific ones. Indexing the
        # defaultdict registers an empty list for op types not seen before.
        checks = self.generic_constraints + self.specific_constraints[op.type]
        for check in checks:
            passed, detail = check(op)
            if passed:
                continue
            print(f"Warning: unsupported TOSA semantics for {tosa_op_type} '{op.name}'.")
            print(f" - {check.__doc__}")
            if detail:
                print(f" {detail}")
            return False

        return True
50
51
def tosa_semantic_checker(nng):
    """Flag every operation in ``nng`` with the result of the TOSA semantic checks.

    A single ``TosaSemantic`` checker is shared across the whole graph. For
    each subgraph in ``nng.subgraphs`` (in order), every operation returned
    by ``get_all_ops()`` gets ``run_on_npu`` set to the checker's verdict.

    Args:
        nng: the graph to check; it is modified in place.

    Returns:
        The same ``nng`` object.
    """
    is_valid = TosaSemantic().is_operator_semantic_valid

    # Lazily walk subgraphs and their ops in the same order as a nested loop.
    every_op = (operation for subgraph in nng.subgraphs for operation in subgraph.get_all_ops())
    for operation in every_op:
        operation.run_on_npu = is_valid(operation)

    return nng
57 return nng