#  SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
#  SPDX-License-Identifier: Apache-2.0
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
15
"""
Utility script to generate a model C++ (.cc) file that can be included in
the project directly. This should be called as part of the CMake framework
should the models need to be generated at the configuration stage.
"""
21import datetime
alexander3c798932021-03-26 21:42:19 +000022from argparse import ArgumentParser
23from pathlib import Path
Richard Burton17069622022-03-17 10:54:26 +000024
alexander3c798932021-03-26 21:42:19 +000025from jinja2 import Environment, FileSystemLoader
Kshitij Sisodia1da52ae2021-06-25 09:55:14 +010026import binascii
alexander3c798932021-03-26 21:42:19 +000027
# Command-line interface.
# NOTE: arguments are parsed at module level, so `args` is populated as soon
# as this file is executed; the entry point passes it on to main().
parser = ArgumentParser()

parser.add_argument("--tflite_path", help="Model (.tflite) path", required=True)
parser.add_argument("--output_dir", help="Output directory", required=True)
# The three options below may be repeated; each occurrence is appended to a
# list that is handed to the 'tflite.cc.template' template by main().
parser.add_argument('-e', '--expression', action='append', default=[], dest="expr",
                    help="Expression passed to the model template (may be repeated)")
parser.add_argument('--header', action='append', default=[], dest="headers",
                    help="Additional header passed to the model template (may be repeated)")
parser.add_argument('-ns', '--namespaces', action='append', default=[], dest="namespaces",
                    help="Namespace passed to the model template (may be repeated)")
parser.add_argument("--license_template", type=str, help="Header template file",
                    default="header_template.txt")
args = parser.parse_args()
38
# Jinja2 environment shared by all template lookups in this script; templates
# are loaded from the 'templates' directory that sits next to this file.
env = Environment(
    loader=FileSystemLoader(Path(__file__).parent / 'templates'),
    trim_blocks=True,
    lstrip_blocks=True,
)
42
43
def get_tflite_data(tflite_path: str) -> list:
    """
    Reads a binary file and returns a C style array as a
    list of strings.

    The file contents are rendered as a brace-enclosed initializer with
    every byte written as a lowercase ``0xNN, `` literal, 32 bytes per
    line. An empty file yields ``{};``.

    Argument:
        tflite_path:    path to the tflite model.

    Returns:
        list containing a single string: the complete array initializer,
        terminated by "};" and a newline.
    """
    with open(tflite_path, 'rb') as tflite_model:
        data = tflite_model.read()

    bytes_per_line = 32

    # Build each output row separately and join once at the end; repeated
    # `+=` on one ever-growing string can degrade to quadratic time on
    # multi-megabyte models.
    rows = (
        "".join(f"0x{byte:02x}, " for byte in data[offset:offset + bytes_per_line])
        for offset in range(0, len(data), bytes_per_line)
    )

    # Every row is preceded by a newline, matching the layout the template
    # has always received.
    hexstring = "{" + "".join("\n" + row for row in rows) + "};\n"
    return [hexstring]
alexander3c798932021-03-26 21:42:19 +000070
71
def main(args):
    """
    Generates a C++ source file that embeds the given tflite model.

    The output file is named after the model file with ".cc" appended and
    is written into args.output_dir. Its content is rendered from
    'tflite.cc.template', with a header rendered from
    args.license_template.

    Argument:
        args:   parsed command-line arguments. The attributes read are
                tflite_path, output_dir, license_template, expr, headers
                and namespaces.

    Raises:
        FileNotFoundError: if args.tflite_path is not an existing file.
    """
    tflite_path = Path(args.tflite_path)
    if not tflite_path.is_file():
        # FileNotFoundError is still an Exception, so existing handlers that
        # catch Exception keep working.
        raise FileNotFoundError(f"{args.tflite_path} not found")

    # Cpp filename:
    cpp_filename = (Path(args.output_dir) / (tflite_path.name + ".cc")).resolve()
    # Single literal: a backslash continuation inside the string would make
    # the printed message depend on the indentation of the source line.
    print(f"++ Converting {tflite_path.name} to {cpp_filename.name}")

    # parents=True so that a nested output directory whose ancestors do not
    # exist yet is created instead of raising.
    cpp_filename.parent.mkdir(parents=True, exist_ok=True)

    header_template = env.get_template(args.license_template)

    # Take the timestamp once so gen_time and year always agree.
    now = datetime.datetime.now()
    hdr = header_template.render(script_name=Path(__file__).name,
                                 file_name=tflite_path.name,
                                 gen_time=now,
                                 year=now.year)

    env.get_template('tflite.cc.template').stream(common_template_header=hdr,
                                                  model_data=get_tflite_data(args.tflite_path),
                                                  expressions=args.expr,
                                                  additional_headers=args.headers,
                                                  namespaces=args.namespaces).dump(str(cpp_filename))
95
96
# Script entry point: run the conversion with the module-level parsed arguments.
if __name__ == "__main__":
    main(args)