Alex Tawse | daba3cf | 2023-09-29 15:55:38 +0100 | [diff] [blame] | 1 | # SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com> |
alexander | 3c79893 | 2021-03-26 21:42:19 +0000 | [diff] [blame] | 2 | # SPDX-License-Identifier: Apache-2.0 |
| 3 | # |
| 4 | # Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | # you may not use this file except in compliance with the License. |
| 6 | # You may obtain a copy of the License at |
| 7 | # |
| 8 | # http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | # |
| 10 | # Unless required by applicable law or agreed to in writing, software |
| 11 | # distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | # See the License for the specific language governing permissions and |
| 14 | # limitations under the License. |
| 15 | |
| 16 | """ |
Utility script to generate a model C++ source (.cc) file that can be included
in the project directly. It should be called as part of the CMake framework
if the models need to be generated at the configuration stage.
| 20 | """ |
Alex Tawse | daba3cf | 2023-09-29 15:55:38 +0100 | [diff] [blame] | 21 | import binascii |
alexander | 3c79893 | 2021-03-26 21:42:19 +0000 | [diff] [blame] | 22 | from argparse import ArgumentParser |
| 23 | from pathlib import Path |
Richard Burton | 1706962 | 2022-03-17 10:54:26 +0000 | [diff] [blame] | 24 | |
alexander | 3c79893 | 2021-03-26 21:42:19 +0000 | [diff] [blame] | 25 | from jinja2 import Environment, FileSystemLoader |
| 26 | |
Alex Tawse | daba3cf | 2023-09-29 15:55:38 +0100 | [diff] [blame] | 27 | from gen_utils import GenUtils |
| 28 | |
# pylint: disable=duplicate-code
parser = ArgumentParser()

# Mandatory inputs: the model to convert and the directory to write into.
parser.add_argument("--tflite_path", help="Model (.tflite) path", required=True)
parser.add_argument("--output_dir", help="Output directory", required=True)

# Repeatable options: every occurrence on the command line is appended to
# the destination list, which is later handed to the .cc template.
parser.add_argument('-e', '--expression', action='append', default=[], dest="expr")
parser.add_argument('--header', action='append', default=[], dest="headers")
parser.add_argument('-ns', '--namespaces', action='append', default=[], dest="namespaces")

# Template used to produce the header placed at the top of the generated file.
parser.add_argument(
    "--license_template",
    type=str,
    help="Header template file",
    default="header_template.txt",
)

parsed_args = parser.parse_args()

# Jinja2 environment that loads templates from the "templates" directory
# sitting next to this script.
env = Environment(
    loader=FileSystemLoader(Path(__file__).parent / 'templates'),
    trim_blocks=True,
    lstrip_blocks=True,
)


# pylint: enable=duplicate-code
def get_tflite_data(tflite_path: str) -> list:
    """
    Reads a binary file and returns its contents as a C style array
    initialiser.

    Argument:
        tflite_path: path to the tflite model.

    Returns:
        A single-element list of strings. The string starts with '{',
        carries every byte as a lower-case "0x.., " literal with a new
        line started before each group of 32 bytes, and ends with "};"
        followed by a newline.
    """
    with open(tflite_path, 'rb') as tflite_model:
        data = tflite_model.read()

    bytes_per_line = 32

    # Build each output row with a single join instead of growing one big
    # string with "+=" per byte; the model can be several megabytes.
    rows = [
        "\n" + "".join(f"0x{byte:02x}, " for byte in data[start:start + bytes_per_line])
        for start in range(0, len(data), bytes_per_line)
    ]

    return ['{' + "".join(rows) + '};\n']
alexander | 3c79893 | 2021-03-26 21:42:19 +0000 | [diff] [blame] | 108 | |
| 109 | |
def main(args):
    """
    Generate models .cpp

    Renders the 'tflite.cc.template' template with the model data and writes
    the result to "<output_dir>/<tflite file name>.cc", creating the output
    directory (and any missing parents) first.

    @param args:    Parsed args; reads tflite_path, output_dir,
                    license_template, expr, headers and namespaces.
    @raise ValueError: if args.tflite_path is not an existing file.
    """
    tflite_path = Path(args.tflite_path)
    if not tflite_path.is_file():
        raise ValueError(f"{args.tflite_path} not found")

    # Cpp filename:
    cpp_filename = (Path(args.output_dir) / (tflite_path.name + ".cc")).resolve()

    # Keep the message on one logical string: a backslash continuation inside
    # the literal would splice the next line's leading text into the output.
    print(f"++ Converting {tflite_path.name} to {cpp_filename.name}")

    # parents=True so that a nested, not yet existing output directory works.
    cpp_filename.parent.mkdir(parents=True, exist_ok=True)

    hdr = GenUtils.gen_header(env, args.license_template, tflite_path.name)

    template = env.get_template('tflite.cc.template')
    template.stream(
        common_template_header=hdr,
        model_data=get_tflite_data(args.tflite_path),
        expressions=args.expr,
        additional_headers=args.headers,
        namespaces=args.namespaces,
    ).dump(str(cpp_filename))
| 135 | |
# Run the conversion only when this file is executed as a script; parsed_args
# is produced by the module-level parser.parse_args() call.
if __name__ == '__main__':
    main(parsed_args)