blob: 3888b801e6d746f253cf1b20019c932b9aa82464 [file] [log] [blame]
Georgios Pinitasa7171102021-08-17 12:54:59 +01001#!/usr/bin/env python3
2# Copyright (c) 2021 Arm Limited.
3#
4# SPDX-License-Identifier: MIT
5#
6# Permission is hereby granted, free of charge, to any person obtaining a copy
7# of this software and associated documentation files (the "Software"), to
8# deal in the Software without restriction, including without limitation the
9# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10# sell copies of the Software, and to permit persons to whom the Software is
11# furnished to do so, subject to the following conditions:
12#
13# The above copyright notice and this permission notice shall be included in all
14# copies or substantial portions of the Software.
15#
16# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22# SOFTWARE.
23import json
24import logging
25import os
26import sys
27from argparse import ArgumentParser
28
29import tflite
30
31sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
32
33from utils.model_identification import identify_model_type
34from utils.tflite_helpers import tflite_op2acl, tflite_typecode2name
35
# Model types this script can analyse; also shown in the --models help text.
SUPPORTED_MODEL_TYPES = ["tflite"]
# Logger shared by the helper functions and the command-line entry point of this script.
logger = logging.getLogger("report_model_ops")
38
39
def get_ops_from_tflite_graph(model):
    """
    Helper function that extracts operator related meta-data from a TfLite model

    Parameters
    ----------
    model: str
        Path to the TfLite model file to analyse

    Returns
    ----------
    supported_ops, unsupported_ops, data_types: tuple
        A tuple of three sets:
        - the ComputeLibrary operators obtained by mapping the model's TfLite operators
        - the TfLite operator names for which no ComputeLibrary mapping could be found
        - the data-types (as returned by tflite_typecode2name) of every tensor in every subgraph
    """

    logger.debug(f"Analysing TfLite model '{model}'!")

    with open(model, "rb") as f:
        buf = f.read()
    # Keep the parsed flatbuffer under its own name so the `model` path argument is not shadowed
    tflite_model = tflite.Model.GetRootAsModel(buf, 0)

    # Extract unique operators
    nr_unique_ops = tflite_model.OperatorCodesLength()
    unique_ops = {
        tflite.opcode2name(tflite_model.OperatorCodes(op_id).BuiltinCode()) for op_id in range(nr_unique_ops)
    }

    # Extract the data-types of all tensors in all subgraphs (not only input/output tensors)
    data_types = set()
    for subgraph_id in range(tflite_model.SubgraphsLength()):
        subgraph = tflite_model.Subgraphs(subgraph_id)
        for tensor_id in range(subgraph.TensorsLength()):
            data_types.add(tflite_typecode2name(subgraph.Tensors(tensor_id).Type()))

    # Perform mapping between TfLite ops to ComputeLibrary ones
    supported_ops = set()
    unsupported_ops = set()
    for top in unique_ops:
        try:
            acl_op = tflite_op2acl(top)
        except Exception:
            # A failed lookup is treated as "no ComputeLibrary equivalent". Catching Exception
            # (instead of a bare except) lets KeyboardInterrupt/SystemExit propagate.
            unsupported_ops.add(top)
            logger.warning(f"Operator {top} has no ComputeLibrary mapping")
        else:
            supported_ops.add(acl_op)

    return (supported_ops, unsupported_ops, data_types)
83
84
def extract_model_meta(model, model_type):
    """
    Function that calls the appropriate model parser to extract model related meta-data
    Supported parsers: TfLite

    Parameters
    ----------
    model: str
        Path to model that we want to analyze
    model_type: str
        Type identifier of the model; only "tflite" is currently parsed

    Returns
    ----------
    supported_ops, unsupported_ops, data_types: tuple
        A tuple of three sets: operators with a ComputeLibrary mapping, operators without one,
        and the data-types present in the model. For an unsupported model type all three sets
        are empty, so callers can always unpack three values.
    """

    if model_type == "tflite":
        return get_ops_from_tflite_graph(model)

    logger.warning(f"Model type '{model_type}' is unsupported!")
    # Return three empty sets rather than an empty tuple: callers unpack three values,
    # and an empty tuple would make that unpacking raise ValueError.
    return (set(), set(), set())
108
109
def generate_build_config(ops, data_types):
    """
    Function that generates a compatible ComputeLibrary operator-based build configuration

    Parameters
    ----------
    ops: set
        Set with the operators to add in the build configuration
    data_types: set
        Set with the data types to add in the build configuration

    Returns
    ----------
    config_data: dict
        Dictionary compatible with ComputeLibrary with two keys, "operators" and "data_types",
        each mapping to a sorted list of the given entries
    """
    # Sort the entries: iteration order of a set of strings can change between interpreter
    # runs (hash randomization), so sorting keeps the generated JSON configuration reproducible.
    config_data = {}
    config_data["operators"] = sorted(ops)
    config_data["data_types"] = sorted(data_types)

    return config_data
131
132
if __name__ == "__main__":
    parser = ArgumentParser(
        description="""Report map of operations in a list of models.
        The script consumes deep learning models and reports the type of operations and data-types used
        Supported model types: TfLite """
    )

    parser.add_argument(
        "-m",
        "--models",
        nargs="+",
        required=True,
        type=str,
        help=f"List of models; supported model types: {SUPPORTED_MODEL_TYPES}",
    )
    parser.add_argument("-D", "--debug", action="store_true", help="Enable script debugging output")
    parser.add_argument(
        "-c",
        "--config",
        type=str,
        help="JSON configuration file that can be used for custom ComputeLibrary builds",
    )
    args = parser.parse_args()

    # Setup Logger
    logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO)

    # Extract operator mapping, accumulated over all requested models
    final_supported_ops = set()
    final_unsupported_ops = set()
    final_dts = set()
    for model in args.models:
        logger.debug(f"Starting analyzing {model} model")

        model_type = identify_model_type(model)
        # Skip models whose type this script cannot parse; nothing can be extracted from them
        if model_type not in SUPPORTED_MODEL_TYPES:
            logger.warning(f"Skipping model '{model}': model type '{model_type}' is unsupported!")
            continue

        supported_model_ops, unsupported_model_ops, model_dts = extract_model_meta(model, model_type)
        final_supported_ops.update(supported_model_ops)
        final_unsupported_ops.update(unsupported_model_ops)
        final_dts.update(model_dts)

    # Report sorted lists so the output is stable between runs
    logger.info("=== Supported Operators")
    logger.info(sorted(final_supported_ops))
    logger.info("=== Unsupported Operators")
    logger.info(sorted(final_unsupported_ops))
    logger.info("=== Data Types")
    logger.info(sorted(final_dts))

    # Generate json file
    if args.config:
        logger.debug("Generating JSON build configuration file")
        config_data = generate_build_config(final_supported_ops, final_dts)
        with open(args.config, "w") as f:
            json.dump(config_data, f)