blob: 219cb3c69b49822bc25bcaa97bdaf7b4c8aed80a [file] [log] [blame]
alexanderf4e2c472021-05-14 13:14:21 +01001#!/usr/bin/env python3
Isabella Gottardi2181d0a2021-04-07 09:27:38 +01002
3# Copyright (c) 2021 Arm Limited. All rights reserved.
4# SPDX-License-Identifier: Apache-2.0
5#
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17
import errno
import fnmatch
import logging
import os
import shlex
import subprocess
import sys
import urllib.request
from argparse import ArgumentParser
from urllib.error import URLError
27
# Every resource is fetched from a pinned ML-zoo commit, so downloads are
# reproducible. Each *_MODEL_DIR below is the raw-file URL of one model's
# directory at that commit; individual resource URLs are built from it.
_ML_ZOO_RAW = "https://github.com/ARM-software/ML-zoo/raw"
_AD_MODEL_DIR = f"{_ML_ZOO_RAW}/7c32b097f7d94aae2cd0b98a8ed5a3ba81e66b18/models/anomaly_detection/micronet_medium/tflite_int8"
_ASR_MODEL_DIR = f"{_ML_ZOO_RAW}/1a92aa08c0de49a7304e0a7f3f59df6f4fd33ac8/models/speech_recognition/wav2letter/tflite_pruned_int8"
_IMG_CLASS_MODEL_DIR = f"{_ML_ZOO_RAW}/e0aa361b03c738047b9147d1a50e3f2dcb13dbcb/models/image_classification/mobilenet_v2_1.0_224/tflite_int8"
_KWS_MODEL_DIR = f"{_ML_ZOO_RAW}/68b5fbc77ed28e67b2efc915997ea4477c1d9d5b/models/keyword_spotting/ds_cnn_large/tflite_clustered_int8"
_VWW_MODEL_DIR = f"{_ML_ZOO_RAW}/7dd3b16bb84007daf88be8648983c07f3eb21140/models/visual_wake_words/micronet_vww4/tflite_int8"
_DNN_SMALL_MODEL_DIR = f"{_ML_ZOO_RAW}/68b5fbc77ed28e67b2efc915997ea4477c1d9d5b/models/keyword_spotting/dnn_small/tflite_int8"

# Resources to download, grouped per use case. Each resource has a "name"
# (the file name it is saved as) and a "url". An optional "sub_folder" puts
# the file under resources_downloaded/<use_case_name>/<sub_folder>/ instead
# of resources_downloaded/<use_case_name>/ (see set_up_resources).
json_uc_res = [
    {
        "use_case_name": "ad",
        "resources": [
            {"name": "ad_medium_int8.tflite",
             "url": f"{_AD_MODEL_DIR}/ad_medium_int8.tflite"},
            {"name": "ifm0.npy",
             "url": f"{_AD_MODEL_DIR}/testing_input/input/0.npy"},
            {"name": "ofm0.npy",
             "url": f"{_AD_MODEL_DIR}/testing_output/Identity/0.npy"},
        ],
    },
    {
        "use_case_name": "asr",
        "resources": [
            {"name": "wav2letter_pruned_int8.tflite",
             "url": f"{_ASR_MODEL_DIR}/wav2letter_pruned_int8.tflite"},
            {"name": "ifm0.npy",
             "url": f"{_ASR_MODEL_DIR}/testing_input/input_2_int8/0.npy"},
            {"name": "ofm0.npy",
             "url": f"{_ASR_MODEL_DIR}/testing_output/Identity_int8/0.npy"},
        ],
    },
    {
        "use_case_name": "img_class",
        "resources": [
            {"name": "mobilenet_v2_1.0_224_INT8.tflite",
             "url": f"{_IMG_CLASS_MODEL_DIR}/mobilenet_v2_1.0_224_INT8.tflite"},
            {"name": "ifm0.npy",
             "url": f"{_IMG_CLASS_MODEL_DIR}/testing_input/tfl.quantize/0.npy"},
            {"name": "ofm0.npy",
             "url": f"{_IMG_CLASS_MODEL_DIR}/testing_output/MobilenetV2/Predictions/Reshape_11/0.npy"},
        ],
    },
    {
        "use_case_name": "kws",
        "resources": [
            {"name": "ds_cnn_clustered_int8.tflite",
             "url": f"{_KWS_MODEL_DIR}/ds_cnn_clustered_int8.tflite"},
            {"name": "ifm0.npy",
             "url": f"{_KWS_MODEL_DIR}/testing_input/input_2/0.npy"},
            {"name": "ofm0.npy",
             "url": f"{_KWS_MODEL_DIR}/testing_output/Identity/0.npy"},
        ],
    },
    {
        "use_case_name": "vww",
        "resources": [
            {"name": "vww4_128_128_INT8.tflite",
             "url": f"{_VWW_MODEL_DIR}/vww4_128_128_INT8.tflite"},
            {"name": "ifm0.npy",
             "url": f"{_VWW_MODEL_DIR}/testing_input/input/0.npy"},
            {"name": "ofm0.npy",
             "url": f"{_VWW_MODEL_DIR}/testing_output/Identity/0.npy"},
        ],
    },
    {
        # Combined use case: reuses the asr and kws models; their test
        # vectors share file names, so they live in separate sub folders.
        "use_case_name": "kws_asr",
        "resources": [
            {"name": "wav2letter_pruned_int8.tflite",
             "url": f"{_ASR_MODEL_DIR}/wav2letter_pruned_int8.tflite"},
            {"sub_folder": "asr", "name": "ifm0.npy",
             "url": f"{_ASR_MODEL_DIR}/testing_input/input_2_int8/0.npy"},
            {"sub_folder": "asr", "name": "ofm0.npy",
             "url": f"{_ASR_MODEL_DIR}/testing_output/Identity_int8/0.npy"},
            {"name": "ds_cnn_clustered_int8.tflite",
             "url": f"{_KWS_MODEL_DIR}/ds_cnn_clustered_int8.tflite"},
            {"sub_folder": "kws", "name": "ifm0.npy",
             "url": f"{_KWS_MODEL_DIR}/testing_input/input_2/0.npy"},
            {"sub_folder": "kws", "name": "ofm0.npy",
             "url": f"{_KWS_MODEL_DIR}/testing_output/Identity/0.npy"},
        ],
    },
    {
        "use_case_name": "inference_runner",
        "resources": [
            {"name": "dnn_s_quantized.tflite",
             "url": f"{_DNN_SMALL_MODEL_DIR}/dnn_s_quantized.tflite"},
        ],
    },
]
94
95
def call_command(command: str) -> str:
    """
    Run a shell command, log its combined output and return it.

    Parameters:
    ----------
    command (string): Specifies the command to run. It is executed through
        the shell (``shell=True``), so callers are responsible for quoting
        arguments that may contain spaces or shell metacharacters.

    Returns:
    -------
    str: The command's stdout and stderr (merged), decoded as UTF-8. Bytes
        that are not valid UTF-8 are replaced with U+FFFD instead of raising,
        so unusual tool output cannot hide the command's real exit status.

    Raises:
    ------
    subprocess.CalledProcessError: If the command exits with a non-zero status.
    """
    logging.info(command)
    proc = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, shell=True)
    log = proc.stdout.decode("utf-8", errors="replace")
    if proc.returncode == 0:
        logging.info(log)
    else:
        logging.error(log)
    # No-op on success; raises CalledProcessError on a non-zero exit code.
    proc.check_returncode()
    return log
Isabella Gottardi2181d0a2021-04-07 09:27:38 +0100113
114
def _make_dir(path: str) -> bool:
    """
    Create directory ``path`` (including missing parents) if needed.

    Returns:
    -------
    bool: True if ``path`` already existed as a directory, False if it was created.

    Raises:
    ------
    OSError: If the directory cannot be created, e.g. because ``path`` exists
        as a regular file.
    """
    if os.path.isdir(path):
        return True
    os.makedirs(path, exist_ok=True)
    return False


def _set_up_venv(download_dir: str) -> str:
    """
    Make sure ``<download_dir>/env`` is a virtual environment with the requirements.

    The environment is created (and pip/setuptools upgraded) only when the
    ``env`` directory is missing; each pinned requirement is installed when it
    is not listed by ``pip freeze``.

    Returns:
    -------
    str: Absolute path of the environment's ``bin/activate`` script.
    """
    env_dir = os.path.join(download_dir, "env")
    env_python = os.path.abspath(os.path.join(env_dir, "bin", "python3"))
    env_activate = os.path.abspath(os.path.join(env_dir, "bin", "activate"))
    quoted_python = shlex.quote(env_python)

    # 1.2 Does the virtual environment exist?
    if not os.path.isdir(env_dir):
        # Pass an absolute target instead of chdir-ing into download_dir, so the
        # caller's working directory is never changed (or left changed on failure).
        call_command(f"python3 -m venv {shlex.quote(env_dir)}")
        for c in ("pip install --upgrade pip", "pip install --upgrade setuptools"):
            call_command(f"{quoted_python} -m {c}")

    # 1.3 Make sure to have all the requirements.
    requirements = ["ethos-u-vela==3.1.0"]
    # Compare whole `pip freeze` lines; a substring test could match a
    # different version that merely starts with the pinned one.
    installed = {line.strip() for line in call_command(f"{quoted_python} -m pip freeze").splitlines()}
    for req in requirements:
        if req not in installed:
            call_command(f"{quoted_python} -m pip install {shlex.quote(req)}")

    return env_activate


def _download(url: str, dst: str) -> None:
    """
    Download ``url`` to ``dst`` without ever leaving a truncated ``dst`` behind.

    Data is written to ``<dst>.part`` and only renamed to ``dst`` once the
    whole payload has been written. Callers skip existing files, so a partial
    ``dst`` would otherwise never be re-downloaded.

    Raises:
    ------
    URLError: If the URL cannot be opened (logged before re-raising).
    """
    tmp_dst = f"{dst}.part"
    try:
        with urllib.request.urlopen(url) as response, open(tmp_dst, "wb") as f:
            f.write(response.read())
        os.replace(tmp_dst, dst)
    except URLError:
        logging.error(f"URLError while downloading {url}.")
        raise
    finally:
        # Only still present if something failed before the rename.
        if os.path.exists(tmp_dst):
            os.remove(tmp_dst)
    logging.info(f"- Downloaded {url} to {dst}.")


def _download_use_case(download_dir: str, uc: dict) -> None:
    """
    Download all resources of one ``json_uc_res`` entry into ``<download_dir>/<use_case_name>``.

    Resources with a "sub_folder" key go into ``<use_case_name>/<sub_folder>``.
    Files that already exist are skipped.
    """
    uc_dir = os.path.join(download_dir, uc["use_case_name"])
    try:
        # Does the usecase_name download dir exist?
        _make_dir(uc_dir)
    except OSError:
        logging.error(f"Error creating {uc['use_case_name']} directory.")
        raise

    for res in uc["resources"]:
        dst_dir = uc_dir
        if "sub_folder" in res:
            dst_dir = os.path.join(uc_dir, res["sub_folder"])
            try:
                # Does the usecase_name/sub_folder download dir exist?
                _make_dir(dst_dir)
            except OSError:
                logging.error(f"Error creating {uc['use_case_name']} / {res['sub_folder']} directory.")
                raise
        res_dst = os.path.join(dst_dir, res["name"])

        if os.path.isfile(res_dst):
            logging.info(f"File {res_dst} exists, skipping download.")
        else:
            _download(res["url"], res_dst)


def _find_models(download_dir: str) -> list:
    """
    Return the sorted paths of ``*.tflite`` files under ``download_dir`` whose name lacks "vela".

    The top-level ``env`` virtual environment is not searched. It holds no
    downloaded models, and walking it is slow.
    """
    models = []
    for dirpath, dirnames, files in os.walk(download_dir):
        if dirpath == download_dir and "env" in dirnames:
            dirnames.remove("env")  # prune in place so os.walk skips it
        models.extend(os.path.join(dirpath, f)
                      for f in fnmatch.filter(files, "*.tflite") if "vela" not in f)
    return sorted(models)


def _vela_model_paths(model: str) -> tuple:
    """
    Return ``(vela_output_path, renamed_path)`` for a ``.tflite`` model path.

    Vela names its output ``<name>_vela.tflite``. This script renames it to
    ``<name>_vela_H128.tflite`` to record the selected MAC configuration.
    Only the file extension is rewritten, never a ".tflite" inside a directory name.
    """
    root, ext = os.path.splitext(model)
    return f"{root}_vela{ext}", f"{root}_vela_H128{ext}"


def _run_vela(download_dir: str, env_activate: str, config_file: str) -> None:
    """
    Compile every model found by ``_find_models`` with vela from the virtual environment.

    Models whose ``_vela_H128`` output already exists are skipped.
    """
    for model in _find_models(download_dir):
        output_dir = os.path.dirname(model)
        vela_optimised_model_path, new_vela_optimised_model_path = _vela_model_paths(model)

        if os.path.isfile(new_vela_optimised_model_path):
            logging.info(f"File {new_vela_optimised_model_path} exists, skipping optimisation.")
            continue

        # The command runs through the shell, so quote every path: a checkout
        # directory containing spaces would otherwise split the arguments.
        command = (f". {shlex.quote(env_activate)} && vela {shlex.quote(model)} " +
                   "--accelerator-config=ethos-u55-128 " +
                   "--optimise Performance " +
                   f"--config {shlex.quote(config_file)} " +
                   "--memory-mode=Shared_Sram " +
                   "--system-config=Ethos_U55_High_End_Embedded " +
                   f"--output-dir={shlex.quote(output_dir)}")
        call_command(command)

        # rename default vela model
        os.rename(vela_optimised_model_path, new_vela_optimised_model_path)
        logging.info(f"Renaming {vela_optimised_model_path} to {new_vela_optimised_model_path}.")


def set_up_resources(run_vela_on_models=False):
    """
    Download the default use-case resources and optionally compile the models with vela.

    Steps:
      1. Create ``resources_downloaded`` next to this script, containing a
         virtual environment ``env`` with the pinned ``ethos-u-vela``.
      2. Download every resource in ``json_uc_res`` into
         ``resources_downloaded/<use_case_name>[/<sub_folder>]``, skipping
         files that already exist.
      3. If requested, run vela on each downloaded ``.tflite`` model whose name
         does not contain "vela". For example, ``ds_cnn_clustered_int8.tflite``
         becomes ``ds_cnn_clustered_int8_vela_H128.tflite``. Downloaded model
         names are therefore assumed not to contain the word "vela".

    Parameters:
    ----------
    run_vela_on_models (bool): Specifies whether to run vela on the downloaded models.

    Raises:
    ------
    OSError: If a download directory cannot be created.
    URLError: If a resource cannot be downloaded.
    subprocess.CalledProcessError: If a venv, pip or vela command fails.
    """
    current_file_dir = os.path.dirname(os.path.abspath(__file__))
    download_dir = os.path.abspath(os.path.join(current_file_dir, "resources_downloaded"))

    # 1.1 Does the download dir exist?
    if _make_dir(download_dir):
        logging.info("'resources_downloaded' directory exists.")

    # 1.2 / 1.3 Virtual environment with all the requirements.
    env_activate = _set_up_venv(download_dir)

    # 2. Download models
    for uc in json_uc_res:
        _download_use_case(download_dir, uc)

    # 3. Run vela on models in resources_downloaded
    if run_vela_on_models:
        config_file = os.path.join(current_file_dir, "scripts", "vela", "default_vela.ini")
        _run_vela(download_dir, env_activate, config_file)
Isabella Gottardi2181d0a2021-04-07 09:27:38 +0100236
237
if __name__ == '__main__':
    # Command-line interface: a single opt-out flag for the Vela step.
    cli = ArgumentParser()
    cli.add_argument(
        "--skip-vela",
        action="store_true",
        help="Do not run Vela optimizer on downloaded models.",
    )
    options = cli.parse_args()

    # Log everything to a file in the current directory and mirror it to stdout.
    logging.basicConfig(filename='log_build_default.log', level=logging.DEBUG)
    stdout_handler = logging.StreamHandler(sys.stdout)
    logging.getLogger().addHandler(stdout_handler)

    should_run_vela = not options.skip_vela
    set_up_resources(should_run_vela)
Isabella Gottardi2181d0a2021-04-07 09:27:38 +0100248 set_up_resources(not args.skip_vela)