blob: 933c18961fd22181f88df23a3e14c72c19bafde2 [file] [log] [blame]
Alex Tawsedaba3cf2023-09-29 15:55:38 +01001# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
alexander3c798932021-03-26 21:42:19 +00002# SPDX-License-Identifier: Apache-2.0
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""
17Utility script to generate model c file that can be included in the
18project directly. This should be called as part of cmake framework
19should the models need to be generated at configuration stage.
20"""
Alex Tawsedaba3cf2023-09-29 15:55:38 +010021import binascii
alexander3c798932021-03-26 21:42:19 +000022from argparse import ArgumentParser
23from pathlib import Path
Richard Burton17069622022-03-17 10:54:26 +000024
alexander3c798932021-03-26 21:42:19 +000025from jinja2 import Environment, FileSystemLoader
26
Alex Tawsedaba3cf2023-09-29 15:55:38 +010027from gen_utils import GenUtils
28
# pylint: disable=duplicate-code
# Command-line interface. Arguments are parsed at import time and the
# result is handed to main() by the __main__ guard at the bottom of the file.
parser = ArgumentParser()

parser.add_argument(
    "--tflite_path",
    help="Model (.tflite) path",
    required=True
)

parser.add_argument(
    "--output_dir",
    help="Output directory",
    required=True
)

parser.add_argument(
    '-e',
    '--expression',
    action='append',
    default=[],
    dest="expr",
    help="Expression passed to the template as 'expressions' (repeatable)"
)

parser.add_argument(
    '--header',
    action='append',
    default=[],
    dest="headers",
    help="Header passed to the template as 'additional_headers' (repeatable)"
)

parser.add_argument(
    '-ns',
    '--namespaces',
    action='append',
    default=[],
    dest="namespaces",
    help="Namespace passed to the template as 'namespaces' (repeatable)"
)

parser.add_argument(
    "--license_template",
    type=str,
    help="Header template file",
    default="header_template.txt"
)

parsed_args = parser.parse_args()

# Templates are looked up in the 'templates' directory next to this script.
env = Environment(loader=FileSystemLoader(Path(__file__).parent / 'templates'),
                  trim_blocks=True,
                  lstrip_blocks=True)


# pylint: enable=duplicate-code
def get_tflite_data(tflite_path: str, bytes_per_line: int = 32) -> list:
    """
    Reads a binary file and returns a C style array initializer as a
    list containing a single string.

    The string has the form ``{\\n0xNN, 0xNN, ..., \\n0xNN, };\\n``: a newline
    is emitted before every group of ``bytes_per_line`` bytes and every byte
    is followed by ``", "``. An empty file yields ``"{};\\n"``.

    Argument:
        tflite_path: path to the tflite model.
        bytes_per_line: number of bytes emitted per output line (default 32).

    Returns:
        list of strings (a single element holding the whole initializer)

    Raises:
        ValueError: if bytes_per_line is not a positive integer.
    """
    if bytes_per_line <= 0:
        raise ValueError(f"bytes_per_line must be positive, got {bytes_per_line}")

    with open(tflite_path, 'rb') as tflite_model:
        data = tflite_model.read()

    # Each byte is two hex digits in the hexlified stream.
    hex_digits_per_line = bytes_per_line * 2
    hexstream = binascii.hexlify(data).decode('utf-8')

    # Collect fragments and join once: repeated str += is quadratic on
    # interpreters without CPython's in-place concatenation optimisation,
    # and model files can be several megabytes.
    parts = ['{']
    for i in range(0, len(hexstream), 2):
        if i % hex_digits_per_line == 0:
            parts.append("\n")
        parts.append(f"0x{hexstream[i:i + 2]}, ")
    parts.append('};\n')

    return [''.join(parts)]
alexander3c798932021-03-26 21:42:19 +0000108
109
def main(args):
    """
    Generate models .cc file from a .tflite model.

    Renders the ``tflite.cc.template`` template with the model bytes and the
    extra expressions/headers/namespaces from ``args``, and writes the result
    to ``<args.output_dir>/<model file name>.cc``.

    @param args: Parsed args
    @raises ValueError: if ``args.tflite_path`` is not an existing file.
    """
    tflite_path = Path(args.tflite_path)
    if not tflite_path.is_file():
        raise ValueError(f"{args.tflite_path} not found")

    # Cpp filename:
    cpp_filename = (Path(args.output_dir) / (tflite_path.name + ".cc")).resolve()
    # Implicit concatenation, not a backslash continuation inside the literal,
    # so no stray indentation ends up in the printed message.
    print(f"++ Converting {tflite_path.name} to "
          f"{cpp_filename.name}")

    # parents=True: the output directory may be nested under directories
    # that do not exist yet.
    cpp_filename.parent.mkdir(parents=True, exist_ok=True)

    hdr = GenUtils.gen_header(env, args.license_template, tflite_path.name)

    env \
        .get_template('tflite.cc.template') \
        .stream(common_template_header=hdr,
                model_data=get_tflite_data(args.tflite_path),
                expressions=args.expr,
                additional_headers=args.headers,
                namespaces=args.namespaces).dump(str(cpp_filename))
135
if __name__ == "__main__":
    # Script entry point: run the generator with the arguments parsed at import.
    main(args=parsed_args)