blob: 484366866602bbd8a2db885e0f49812a698fe26b [file] [log] [blame]
alexander3c798932021-03-26 21:42:19 +00001# Copyright (c) 2021 Arm Limited. All rights reserved.
2# SPDX-License-Identifier: Apache-2.0
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""
Utility script to generate a C++ source file (.cc) containing the model data,
which can be included in the project directly. It should be called as part of
the CMake framework when the models need to be generated at the configuration stage.
20"""
21import datetime
22import os
23from argparse import ArgumentParser
24from pathlib import Path
25from jinja2 import Environment, FileSystemLoader
26
# Command-line interface. The module docstring doubles as the --help description.
parser = ArgumentParser(description=__doc__)

parser.add_argument("--tflite_path", help="Model (.tflite) path", required=True)
parser.add_argument("--output_dir", help="Output directory", required=True)
# The three repeatable options below are collected into lists and forwarded
# unchanged to the 'tflite.cc.template' template by main().
parser.add_argument('-e', '--expression', action='append', default=[], dest="expr",
                    help="Expression forwarded to the source template "
                         "(as 'expressions'); may be given multiple times")
parser.add_argument('--header', action='append', default=[], dest="headers",
                    help="Additional header forwarded to the source template "
                         "(as 'additional_headers'); may be given multiple times")
parser.add_argument('-ns', '--namespaces', action='append', default=[], dest="namespaces",
                    help="Namespace forwarded to the source template "
                         "(as 'namespaces'); may be given multiple times")
parser.add_argument("--license_template", type=str, help="Header template file",
                    default="header_template.txt")
args = parser.parse_args()

# Templates are looked up in the 'templates' directory next to this script.
# trim_blocks/lstrip_blocks strip the whitespace around Jinja block tags so that
# control statements do not leave stray blank lines or indentation in the output.
env = Environment(loader=FileSystemLoader(os.path.join(os.path.dirname(__file__), 'templates')),
                  trim_blocks=True,
                  lstrip_blocks=True)
40 lstrip_blocks=True)
41
42
def write_tflite_data(tflite_path, values_per_line=20):
    """Yield the model file's bytes as chunks of a C/C++ array initialiser.

    The first chunk starts with an opening brace and a newline. Every full
    chunk holds ``values_per_line`` hex literals, each followed by ``', '``.
    The last chunk ends with ``'};'`` and a newline. Full chunks carry no
    trailing newline. NOTE(review): the consuming template presumably
    inserts the line breaks between chunks; confirm against tflite.cc.template.

    Args:
        tflite_path: Path of the model file to embed.
        values_per_line: Number of byte literals per yielded chunk (default 20).

    Yields:
        str: Successive pieces of the initialiser text.

    Raises:
        ValueError: On first iteration, if ``values_per_line`` is less than 1.
    """
    if values_per_line < 1:
        raise ValueError("values_per_line must be a positive integer")

    line = '{\n'
    for count, hex_byte in enumerate(model_hex_bytes(tflite_path), start=1):
        line += hex_byte + ', '
        if count % values_per_line == 0:
            yield line
            line = ''

    # Remove the separator after the last element, but only if this chunk
    # actually contains one. For an empty model `line` is still the opening
    # brace, which must be kept so the emitted initialiser stays well-formed.
    if line.endswith(', '):
        line = line[:-2]
    yield line + '};\n'


def model_hex_bytes(tflite_path, chunk_size=64 * 1024):
    """Yield each byte of the file at ``tflite_path`` as a ``'0x..'`` string.

    Every value is two lowercase hex digits with a ``0x`` prefix, e.g. ``'0x0a'``.

    Args:
        tflite_path: Path of the file to read (opened in binary mode).
        chunk_size: Number of bytes fetched per read() call.

    Yields:
        str: One hex literal per input byte, in file order.
    """
    with open(tflite_path, 'rb') as tflite_model:
        # Read in blocks rather than calling read(1) once per byte. Iterating a
        # bytes object yields ints, which are formatted directly.
        for chunk in iter(lambda: tflite_model.read(chunk_size), b''):
            for value in chunk:
                yield f'0x{value:02x}'
69
70
def main(args):
    """Render ``<output_dir>/<model file name>.cc`` from the given model file.

    The license/header template is rendered first. Its output is then passed,
    together with the model bytes and the user-supplied expressions, headers
    and namespaces, to the 'tflite.cc.template' template. The result is
    streamed into the output file. The output directory is created if needed.

    Args:
        args: Parsed command-line namespace providing ``tflite_path``,
            ``output_dir``, ``license_template``, ``expr``, ``headers`` and
            ``namespaces``.

    Raises:
        FileNotFoundError: If ``args.tflite_path`` is not an existing file.
    """
    tflite_path = Path(args.tflite_path)
    if not tflite_path.is_file():
        raise FileNotFoundError(f"{args.tflite_path} not found")

    # Output file: <output_dir>/<model file name>.cc, as an absolute path.
    cpp_filename = Path(args.output_dir, tflite_path.name + ".cc").absolute()
    print(f"++ Converting {tflite_path.name} to {cpp_filename.name}")

    cpp_filename.parent.mkdir(parents=True, exist_ok=True)

    header_template = env.get_template(args.license_template)

    # Take a single timestamp so gen_time and year can never disagree
    # (e.g. when the script runs across midnight on New Year's Eve).
    now = datetime.datetime.now()
    hdr = header_template.render(script_name=os.path.basename(__file__),
                                 file_name=tflite_path.name,
                                 gen_time=now,
                                 year=now.year)

    # model_data is a generator, so the model bytes are streamed into the
    # output file rather than built up as one large string first.
    env.get_template('tflite.cc.template').stream(
        common_template_header=hdr,
        model_data=write_tflite_data(args.tflite_path),
        expressions=args.expr,
        additional_headers=args.headers,
        namespaces=args.namespaces).dump(str(cpp_filename))
94
95
if __name__ == "__main__":
    # Generate the model source only when run as a script (e.g. from the CMake
    # configuration step), never as a side effect of importing this module.
    main(args)