blob: c43c93a4330c24051a9df8978bbe06d9ca189916 [file] [log] [blame]
alexander3c798932021-03-26 21:42:19 +00001# Copyright (c) 2021 Arm Limited. All rights reserved.
2# SPDX-License-Identifier: Apache-2.0
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
"""
Utility script to generate a C++ source file (.cc) from a TensorFlow Lite
model so that it can be included in the project directly. It is intended to
be called from the CMake framework when the models need to be generated at
the configuration stage.
"""
21import datetime
22import os
23from argparse import ArgumentParser
24from pathlib import Path
25from jinja2 import Environment, FileSystemLoader
Kshitij Sisodia1da52ae2021-06-25 09:55:14 +010026import binascii
alexander3c798932021-03-26 21:42:19 +000027
# Command line interface. Arguments are parsed at import time and the
# resulting namespace is handed to main() by the __main__ guard.
parser = ArgumentParser()

parser.add_argument("--tflite_path", help="Model (.tflite) path", required=True)
parser.add_argument("--output_dir", help="Output directory", required=True)
# The three repeatable options below are collected into lists and passed
# straight through to the 'tflite.cc.template' template by main().
parser.add_argument('-e', '--expression', action='append', default=[], dest="expr",
                    help="Expression forwarded to the code template (repeatable)")
parser.add_argument('--header', action='append', default=[], dest="headers",
                    help="Additional header forwarded to the code template (repeatable)")
parser.add_argument('-ns', '--namespaces', action='append', default=[], dest="namespaces",
                    help="Namespace forwarded to the code template (repeatable)")
parser.add_argument("--license_template", type=str,
                    help="Header template file (looked up in the templates directory)",
                    default="header_template.txt")
args = parser.parse_args()

# Jinja environment that loads templates from the 'templates' directory
# next to this script; trim/lstrip blocks keep the generated code free of
# stray whitespace around template tags.
env = Environment(loader=FileSystemLoader(os.path.join(os.path.dirname(__file__), 'templates')),
                  trim_blocks=True,
                  lstrip_blocks=True)
42
43
def get_tflite_data(tflite_path: str, bytes_per_line: int = 32) -> list:
    """
    Reads a binary file and returns its contents as a C style array
    initialiser, wrapped in a list of strings.

    The returned string starts with '{', writes every byte as '0xNN, '
    (lower-case hex), inserts a newline before each group of
    `bytes_per_line` bytes, and ends with '};' followed by a newline.

    Argument:
        tflite_path: path to the tflite model.
        bytes_per_line: number of bytes emitted per output line (default 32).

    Returns:
        list of strings (currently always a single element).

    Raises:
        ValueError: if bytes_per_line is not a positive integer.
    """
    if bytes_per_line <= 0:
        raise ValueError(f"bytes_per_line must be positive, got {bytes_per_line}")

    with open(tflite_path, 'rb') as tflite_model:
        data = tflite_model.read()

    # Collect the pieces and join once: repeated `+=` on a str is quadratic
    # in the worst case and the input file may be large.
    lines = []
    for start in range(0, len(data), bytes_per_line):
        chunk = data[start:start + bytes_per_line]
        lines.append("\n" + "".join(f"0x{byte:02x}, " for byte in chunk))

    return ['{' + "".join(lines) + '};\n']
alexander3c798932021-03-26 21:42:19 +000070
71
def main(args):
    """
    Renders 'tflite.cc.template' into <output_dir>/<model file name>.cc,
    embedding the model bytes (from get_tflite_data) and a header rendered
    from the licence template.

    Argument:
        args: parsed command line namespace; reads tflite_path, output_dir,
              license_template, expr, headers and namespaces.

    Raises:
        FileNotFoundError: if args.tflite_path is not an existing file.
    """
    if not os.path.isfile(args.tflite_path):
        # FileNotFoundError subclasses Exception, so existing
        # `except Exception` handlers still catch it.
        raise FileNotFoundError(f"{args.tflite_path} not found")

    model_name = os.path.basename(args.tflite_path)

    # Cpp filename: <output_dir>/<model file name>.cc
    cpp_filename = (Path(args.output_dir) / (model_name + ".cc")).absolute()
    # Single f-string: a backslash continuation inside the literal would
    # copy the next line's indentation into the printed message.
    print(f"++ Converting {model_name} to {cpp_filename.name}")

    os.makedirs(cpp_filename.parent, exist_ok=True)

    header_template = env.get_template(args.license_template)

    # Take one timestamp so gen_time and year cannot disagree
    # (e.g. when the script runs across midnight on 31 December).
    now = datetime.datetime.now()
    hdr = header_template.render(script_name=os.path.basename(__file__),
                                 file_name=model_name,
                                 gen_time=now,
                                 year=now.year)

    env.get_template('tflite.cc.template').stream(common_template_header=hdr,
                                                  model_data=get_tflite_data(args.tflite_path),
                                                  expressions=args.expr,
                                                  additional_headers=args.headers,
                                                  namespaces=args.namespaces).dump(str(cpp_filename))
95
96
if __name__ == "__main__":
    # Script entry point; `args` is the namespace parsed at module level.
    main(args=args)