blob: 1549005da5e17d5499d433b6b0805ef756e1e6a9 [file] [log] [blame]
Georgios Pinitasa7171102021-08-17 12:54:59 +01001#!/usr/bin/env python3
2# Copyright (c) 2021 Arm Limited.
3#
4# SPDX-License-Identifier: MIT
5#
6# Permission is hereby granted, free of charge, to any person obtaining a copy
7# of this software and associated documentation files (the "Software"), to
8# deal in the Software without restriction, including without limitation the
9# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10# sell copies of the Software, and to permit persons to whom the Software is
11# furnished to do so, subject to the following conditions:
12#
13# The above copyright notice and this permission notice shall be included in all
14# copies or substantial portions of the Software.
15#
16# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22# SOFTWARE.
23import json
24import logging
25import os
26import sys
27from argparse import ArgumentParser
28
29import tflite
30
31sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
32
33from utils.model_identification import identify_model_type
Freddie Liardet487d3902021-09-21 12:36:43 +010034from utils.tflite_helpers import tflite_op2acl, tflite_typecode2name, tflite_typecode2aclname
Georgios Pinitasa7171102021-08-17 12:54:59 +010035
# Logger used for this script's diagnostic output
logger = logging.getLogger("report_model_ops")
# Model formats accepted by this script (also listed in the --models help text)
SUPPORTED_MODEL_TYPES = ["tflite"]
38
39
def _operator_builtin_code(op_code):
    """
    Return the effective builtin operator code of a TFLite ``OperatorCode`` table entry.

    Newer TFLite schemas store the code in ``builtin_code`` and keep the legacy byte-sized
    ``deprecated_builtin_code`` for older models; TFLite itself resolves the real code as the
    larger of the two. Versions of the ``tflite`` package generated from older schemas do not
    expose ``DeprecatedBuiltinCode``, in which case ``BuiltinCode`` is used unchanged.
    """
    code = op_code.BuiltinCode()
    deprecated_code = getattr(op_code, "DeprecatedBuiltinCode", None)
    if deprecated_code is not None:
        code = max(code, deprecated_code())
    return code


def get_ops_types_from_tflite_graph(model):
    """
    Helper function that extracts operator related meta-data from a TFLite model

    Parameters
    ----------
    model: str
        Path to the TFLite model to analyse

    Returns
    ----------
    supported_ops, unsupported_ops, supported_data_types, unsupported_data_types: tuple
        A tuple of four sets:
        - operators as mapped by ``tflite_op2acl``
        - TFLite operator names for which ``tflite_op2acl`` raised ValueError
        - tensor data types as named by ``tflite_typecode2aclname``
        - TFLite data-type names for which ``tflite_typecode2aclname`` raised ValueError
    """

    logger.debug(f"Analysing TFLite model '{model}'!")

    with open(model, "rb") as f:
        buf = f.read()
    # Keep ``model`` as the path; the parsed flatbuffer gets its own name
    tflite_model = tflite.Model.GetRootAsModel(buf, 0)

    # Extract unique operators
    unique_ops = {
        tflite.opcode2name(_operator_builtin_code(tflite_model.OperatorCodes(op_id)))
        for op_id in range(tflite_model.OperatorCodesLength())
    }

    # Extract IO data-types
    supported_data_types = set()
    unsupported_data_types = set()
    for subgraph_id in range(tflite_model.SubgraphsLength()):
        subgraph = tflite_model.Subgraphs(subgraph_id)
        for tensor_id in range(subgraph.TensorsLength()):
            type_code = subgraph.Tensors(tensor_id).Type()
            try:
                supported_data_types.add(tflite_typecode2aclname(type_code))
            except ValueError:
                type_name = tflite_typecode2name(type_code)
                # Warn once per unsupported type rather than once per tensor
                if type_name not in unsupported_data_types:
                    logger.warning(f"Data type {type_name} is not supported by ComputeLibrary")
                unsupported_data_types.add(type_name)

    # Perform mapping between TfLite ops to ComputeLibrary ones
    supported_ops = set()
    unsupported_ops = set()
    for top in unique_ops:
        try:
            supported_ops.add(tflite_op2acl(top))
        except ValueError:
            unsupported_ops.add(top)
            logger.warning(f"Operator {top} does not have ComputeLibrary mapping")

    return (supported_ops, unsupported_ops, supported_data_types, unsupported_data_types)
Georgios Pinitasa7171102021-08-17 12:54:59 +010088
89
def extract_model_meta(model, model_type):
    """
    Function that calls the appropriate model parser to extract model related meta-data
    Supported parsers: TFLite

    Parameters
    ----------
    model: str
        Path to model that we want to analyze
    model_type:
        Model type identifier; only "tflite" is handled

    Returns
    ----------
    supported_ops, unsupported_ops, supported_data_types, unsupported_data_types: tuple
        A tuple of four sets describing the operators and data-types present in the model.
        For an unsupported model type a warning is logged and all four sets are empty, so
        the result can always be unpacked into four names.
    """

    if model_type == "tflite":
        return get_ops_types_from_tflite_graph(model)

    logger.warning(f"Model type '{model_type}' is unsupported!")
    # Same arity as the supported path so callers' 4-way unpacking never fails
    return (set(), set(), set(), set())
113
114
def generate_build_config(ops, data_types, data_layouts):
    """
    Function that generates a compatible ComputeLibrary operator-based build configuration

    Every collection is emitted as a sorted list, so identical inputs always produce an
    identical configuration. (Iteration order of a set of strings otherwise depends on the
    interpreter's hash seed and changes between runs.) Entries are expected to be mutually
    comparable, e.g. all strings.

    Parameters
    ----------
    ops: set
        Set with the operators to add in the build configuration
    data_types:
        Set with the data types to add in the build configuration
    data_layouts:
        Set with the data layouts to add in the build configuration

    Returns
    ----------
    config_data: dict
        Dictionary compatible with ComputeLibrary, with keys "operators", "data_types"
        and "data_layouts"
    """
    return {
        "operators": sorted(ops),
        "data_types": sorted(data_types),
        "data_layouts": sorted(data_layouts),
    }
139
140
if __name__ == "__main__":
    parser = ArgumentParser(
        description="""Report map of operations in a list of models.
    The script consumes deep learning models and reports the type of operations and data-types used
    Supported model types: TFLite """
    )

    parser.add_argument(
        "-m",
        "--models",
        nargs="+",
        required=True,
        type=str,
        help=f"List of models; supported model types: {SUPPORTED_MODEL_TYPES}",
    )
    parser.add_argument("-D", "--debug", action="store_true", help="Enable script debugging output")
    parser.add_argument(
        "-c",
        "--config",
        type=str,
        help="JSON configuration file that can be used for custom ComputeLibrary builds",
    )
    args = parser.parse_args()

    # Setup Logger
    logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO)

    # Extract operator mapping, accumulated across all requested models
    final_supported_ops = set()
    final_unsupported_ops = set()
    final_supported_dts = set()
    final_unsupported_dts = set()
    final_layouts = {"nhwc"}  # Data layout for TFLite is always NHWC
    for model in args.models:
        logger.debug(f"Starting analyzing {model} model")

        model_type = identify_model_type(model)
        if model_type not in SUPPORTED_MODEL_TYPES:
            # No parser exists for this format: skip it instead of aborting the whole report
            logger.warning(f"Skipping '{model}': model type '{model_type}' is unsupported!")
            continue

        supported_model_ops, unsupported_model_ops, supported_model_dts, unsupported_model_dts = extract_model_meta(
            model, model_type
        )
        final_supported_ops.update(supported_model_ops)
        final_unsupported_ops.update(unsupported_model_ops)
        final_supported_dts.update(supported_model_dts)
        final_unsupported_dts.update(unsupported_model_dts)

    logger.info("=== Supported Operators")
    logger.info(final_supported_ops)
    if final_unsupported_ops:
        logger.info("=== Unsupported Operators")
        logger.info(final_unsupported_ops)
    logger.info("=== Data Types")
    logger.info(final_supported_dts)
    if final_unsupported_dts:
        logger.info("=== Unsupported Data Types")
        logger.info(final_unsupported_dts)
    logger.info("=== Data Layouts")
    logger.info(final_layouts)

    # Generate JSON file
    if args.config:
        logger.debug("Generating JSON build configuration file")
        config_data = generate_build_config(final_supported_ops, final_supported_dts, final_layouts)
        with open(args.config, "w") as f:
            json.dump(config_data, f)