#!env/bin/python3

# Copyright (c) 2021 Arm Limited. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import errno
import fnmatch
import logging
import os
import subprocess
import sys
import urllib.request
from argparse import ArgumentParser
from urllib.error import URLError
# Catalogue of the files each use case needs: the NN model plus reference
# input ("ifm*") and output ("ofm*") tensors. An optional "sub_folder" key
# places a resource inside a subdirectory of the use case's download folder.
json_uc_res = [
    {
        "use_case_name": "ad",
        "resources": [
            {"name": "ad_medium_int8.tflite",
             "url": "https://github.com/ARM-software/ML-zoo/raw/7c32b097f7d94aae2cd0b98a8ed5a3ba81e66b18/models/anomaly_detection/micronet_medium/tflite_int8/ad_medium_int8.tflite"},
            {"name": "ifm0.npy",
             "url": "https://github.com/ARM-software/ML-zoo/raw/7c32b097f7d94aae2cd0b98a8ed5a3ba81e66b18/models/anomaly_detection/micronet_medium/tflite_int8/testing_input/input/0.npy"},
            {"name": "ofm0.npy",
             "url": "https://github.com/ARM-software/ML-zoo/raw/7c32b097f7d94aae2cd0b98a8ed5a3ba81e66b18/models/anomaly_detection/micronet_medium/tflite_int8/testing_output/Identity/0.npy"},
        ],
    },
    {
        "use_case_name": "asr",
        "resources": [
            {"name": "wav2letter_int8.tflite",
             "url": "https://github.com/ARM-software/ML-zoo/raw/68b5fbc77ed28e67b2efc915997ea4477c1d9d5b/models/speech_recognition/wav2letter/tflite_int8/wav2letter_int8.tflite"},
            {"name": "ifm0.npy",
             "url": "https://github.com/ARM-software/ML-zoo/raw/68b5fbc77ed28e67b2efc915997ea4477c1d9d5b/models/speech_recognition/wav2letter/tflite_int8/testing_input/input_2_int8/0.npy"},
            {"name": "ofm0.npy",
             "url": "https://github.com/ARM-software/ML-zoo/raw/68b5fbc77ed28e67b2efc915997ea4477c1d9d5b/models/speech_recognition/wav2letter/tflite_int8/testing_output/Identity_int8/0.npy"},
        ],
    },
    {
        "use_case_name": "img_class",
        "resources": [
            {"name": "mobilenet_v2_1.0_224_quantized_1_default_1.tflite",
             "url": "https://github.com/ARM-software/ML-zoo/raw/68b5fbc77ed28e67b2efc915997ea4477c1d9d5b/models/image_classification/mobilenet_v2_1.0_224/tflite_uint8/mobilenet_v2_1.0_224_quantized_1_default_1.tflite"},
            {"name": "ifm0.npy",
             "url": "https://github.com/ARM-software/ML-zoo/raw/68b5fbc77ed28e67b2efc915997ea4477c1d9d5b/models/image_classification/mobilenet_v2_1.0_224/tflite_uint8/testing_input/input/0.npy"},
            {"name": "ofm0.npy",
             "url": "https://github.com/ARM-software/ML-zoo/raw/68b5fbc77ed28e67b2efc915997ea4477c1d9d5b/models/image_classification/mobilenet_v2_1.0_224/tflite_uint8/testing_output/output/0.npy"},
        ],
    },
    {
        "use_case_name": "kws",
        "resources": [
            {"name": "ds_cnn_clustered_int8.tflite",
             "url": "https://github.com/ARM-software/ML-zoo/raw/68b5fbc77ed28e67b2efc915997ea4477c1d9d5b/models/keyword_spotting/ds_cnn_large/tflite_clustered_int8/ds_cnn_clustered_int8.tflite"},
            {"name": "ifm0.npy",
             "url": "https://github.com/ARM-software/ML-zoo/raw/68b5fbc77ed28e67b2efc915997ea4477c1d9d5b/models/keyword_spotting/ds_cnn_large/tflite_clustered_int8/testing_input/input_2/0.npy"},
            {"name": "ofm0.npy",
             "url": "https://github.com/ARM-software/ML-zoo/raw/68b5fbc77ed28e67b2efc915997ea4477c1d9d5b/models/keyword_spotting/ds_cnn_large/tflite_clustered_int8/testing_output/Identity/0.npy"},
        ],
    },
    {
        # Combined use case: needs both the ASR and the KWS resources, with
        # the reference tensors separated into per-model sub-folders.
        "use_case_name": "kws_asr",
        "resources": [
            {"name": "wav2letter_int8.tflite",
             "url": "https://github.com/ARM-software/ML-zoo/raw/68b5fbc77ed28e67b2efc915997ea4477c1d9d5b/models/speech_recognition/wav2letter/tflite_int8/wav2letter_int8.tflite"},
            {"sub_folder": "asr", "name": "ifm0.npy",
             "url": "https://github.com/ARM-software/ML-zoo/raw/68b5fbc77ed28e67b2efc915997ea4477c1d9d5b/models/speech_recognition/wav2letter/tflite_int8/testing_input/input_2_int8/0.npy"},
            {"sub_folder": "asr", "name": "ofm0.npy",
             "url": "https://github.com/ARM-software/ML-zoo/raw/68b5fbc77ed28e67b2efc915997ea4477c1d9d5b/models/speech_recognition/wav2letter/tflite_int8/testing_output/Identity_int8/0.npy"},
            {"name": "ds_cnn_clustered_int8.tflite",
             "url": "https://github.com/ARM-software/ML-zoo/raw/68b5fbc77ed28e67b2efc915997ea4477c1d9d5b/models/keyword_spotting/ds_cnn_large/tflite_clustered_int8/ds_cnn_clustered_int8.tflite"},
            {"sub_folder": "kws", "name": "ifm0.npy",
             "url": "https://github.com/ARM-software/ML-zoo/raw/68b5fbc77ed28e67b2efc915997ea4477c1d9d5b/models/keyword_spotting/ds_cnn_large/tflite_clustered_int8/testing_input/input_2/0.npy"},
            {"sub_folder": "kws", "name": "ofm0.npy",
             "url": "https://github.com/ARM-software/ML-zoo/raw/68b5fbc77ed28e67b2efc915997ea4477c1d9d5b/models/keyword_spotting/ds_cnn_large/tflite_clustered_int8/testing_output/Identity/0.npy"},
        ],
    },
    {
        "use_case_name": "inference_runner",
        "resources": [
            {"name": "dnn_s_quantized.tflite",
             "url": "https://github.com/ARM-software/ML-zoo/raw/68b5fbc77ed28e67b2efc915997ea4477c1d9d5b/models/keyword_spotting/dnn_small/tflite_int8/dnn_s_quantized.tflite"},
        ],
    },
]
def call_command(command: str) -> str:
    """
    Run a command in a shell and return its combined stdout/stderr output.

    Parameters:
    ----------
    command (string): Specifies the command to run.

    Returns:
    -------
    The command's stdout and stderr, interleaved and decoded as UTF-8.
    """
    logging.info(command)
    # shell=True is required: callers pass compound shell commands such as
    # ". <activate> && vela ...", which need shell syntax to work.
    proc = subprocess.run(command,
                          stdout=subprocess.PIPE,
                          stderr=subprocess.STDOUT,
                          shell=True)
    stdout_log = proc.stdout.decode("utf-8")
    logging.info(stdout_log)
    return stdout_log
def set_up_resources(run_vela_on_models=False):
    """
    Set up the 'resources_downloaded' folder: create a Python virtual
    environment with the Vela compiler installed, download every model and
    reference tensor listed in `json_uc_res` and, optionally, compile the
    downloaded models with Vela.

    Parameters:
    ----------
    run_vela_on_models (bool): Specifies if run vela on downloaded models.
    """
    current_file_dir = os.path.dirname(os.path.abspath(__file__))
    download_dir = os.path.abspath(os.path.join(current_file_dir, "resources_downloaded"))
    logging.basicConfig(filename='log_build_default.log', level=logging.DEBUG)
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))

    try:
        # 1.1 Does the download dir exist?
        os.mkdir(download_dir)
    except OSError as e:
        if e.errno == errno.EEXIST:
            logging.info("'resources_downloaded' directory exists.")
        else:
            raise

    # 1.2 Does the virtual environment exist?
    env_python = str(os.path.abspath(os.path.join(download_dir, "env", "bin", "python3")))
    env_activate = str(os.path.abspath(os.path.join(download_dir, "env", "bin", "activate")))
    if not os.path.isdir(os.path.join(download_dir, "env")):
        os.chdir(download_dir)
        # Create the virtual environment
        command = "python3 -m venv env"
        call_command(command)
        commands = ["pip install --upgrade pip", "pip install --upgrade setuptools"]
        for c in commands:
            command = f"{env_python} -m {c}"
            call_command(command)
        os.chdir(current_file_dir)
    # 1.3 Make sure to have all the requirements
    requirements = ["ethos-u-vela==2.1.1"]
    command = f"{env_python} -m pip freeze"
    packages = call_command(command)
    for req in requirements:
        # Substring match against `pip freeze` output; the exact pin string
        # (name==version) keeps this unambiguous.
        if req not in packages:
            command = f"{env_python} -m pip install {req}"
            call_command(command)

    # 2. Download models
    for uc in json_uc_res:
        try:
            # Does the usecase_name download dir exist?
            os.mkdir(os.path.join(download_dir, uc["use_case_name"]))
        except OSError as e:
            if e.errno != errno.EEXIST:
                logging.error(f"Error creating {uc['use_case_name']} directory.")
                raise

        for res in uc["resources"]:
            res_name = res["name"]
            res_url = res["url"]
            # Optional sub-folder below the use case's directory; an empty
            # string is a no-op component for os.path.join.
            sub_folder = res.get("sub_folder", "")
            if sub_folder:
                try:
                    # Does the usecase_name/sub_folder download dir exist?
                    os.mkdir(os.path.join(download_dir, uc["use_case_name"], sub_folder))
                except OSError as e:
                    if e.errno != errno.EEXIST:
                        logging.error(f"Error creating {uc['use_case_name']} / {sub_folder} directory.")
                        raise
            res_dst = os.path.join(download_dir,
                                   uc["use_case_name"],
                                   sub_folder,
                                   res_name)
            try:
                # Context managers close both the HTTP response (the original
                # code leaked it) and the destination file deterministically.
                with urllib.request.urlopen(res_url) as g, open(res_dst, 'wb') as f:
                    f.write(g.read())
                logging.info(f"- Downloaded {res_url} to {res_dst}.")
            except URLError:
                logging.error(f"URLError while downloading {res_url}.")
                raise

    # 3. Run vela on models in resources_downloaded
    # New models will have same name with '_vela' appended.
    # For example:
    #   original model:   ds_cnn_clustered_int8.tflite
    #   after vela model: ds_cnn_clustered_int8_vela.tflite
    #
    # Note: To avoid running vela twice on the same model, it's assumed that
    # downloaded model names don't contain the 'vela' word.
    if run_vela_on_models:
        config_file = os.path.join(current_file_dir, "scripts", "vela", "default_vela.ini")
        models = [os.path.join(dirpath, f)
                  for dirpath, dirnames, files in os.walk(download_dir)
                  for f in fnmatch.filter(files, '*.tflite') if "vela" not in f]

        for model in models:
            output_dir = os.path.dirname(model)
            # Sourcing the venv's activate script puts `vela` on PATH.
            command = (f". {env_activate} && vela {model} " +
                       "--accelerator-config=ethos-u55-128 " +
                       "--block-config-limit=0 " +
                       f"--config {config_file} " +
                       "--memory-mode=Shared_Sram " +
                       "--system-config=Ethos_U55_High_End_Embedded " +
                       f"--output-dir={output_dir}")
            call_command(command)
if __name__ == '__main__':
    # Vela runs by default; --skip-vela downloads the resources only.
    arg_parser = ArgumentParser()
    arg_parser.add_argument("--skip-vela",
                            help="Do not run Vela optimizer on downloaded models.",
                            action="store_true")
    parsed_args = arg_parser.parse_args()
    set_up_resources(not parsed_args.skip_vela)