#
# SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
"""
Downloads resources for tests from the Arm public model zoo.
Run this script before executing tests.
"""
# The docstring must be the first statement in the module to become
# ``__doc__``; placed after the imports it was a discarded string expression.
import os
from pathlib import Path
from typing import List
from urllib.request import urlopen
| 13 | |
| 14 | |
| 15 | PMZ_URL = 'https://github.com/ARM-software/ML-zoo/raw/9f506fe52b39df545f0e6c5ff9223f671bc5ae00/models' |
| 16 | test_resources = [ |
| 17 | {'model': '{}/visual_wake_words/micronet_vww2/tflite_int8/vww2_50_50_INT8.tflite'.format(PMZ_URL), |
| 18 | 'ifm': '{}/visual_wake_words/micronet_vww2/tflite_int8/testing_input/input/0.npy'.format(PMZ_URL), |
| 19 | 'ofm': '{}/visual_wake_words/micronet_vww2/tflite_int8/testing_output/Identity/0.npy'.format(PMZ_URL)} |
| 20 | ] |
| 21 | |
| 22 | |
def download(path: str, url: str):
    """Fetch ``url`` and store its body at ``path``.

    The body is streamed into a sibling ``<path>.part`` file, which is moved
    over ``path`` only once the transfer has completed. An interrupted or
    failed download therefore never leaves a truncated file at ``path``, and
    any file already at ``path`` is left untouched in that case.

    :param path: Destination file path.
    :param url: URL to fetch (any scheme accepted by ``urllib.request.urlopen``).
    :raises urllib.error.URLError: if the URL cannot be opened.
    :raises OSError: if the destination cannot be written.
    """
    import shutil

    # Announce before connecting so a slow or hanging request is attributable.
    print("Downloading {} ...".format(url))
    partial_path = os.fspath(path) + '.part'
    try:
        with urlopen(url) as response, open(partial_path, 'wb') as file:
            # Stream in chunks instead of holding the whole model in memory.
            shutil.copyfileobj(response, file)
        # Atomic on the same filesystem; overwrites an existing ``path``.
        os.replace(partial_path, path)
    except BaseException:
        # Covers Ctrl-C as well; drop the incomplete temporary file, then re-raise.
        try:
            os.remove(partial_path)
        except FileNotFoundError:
            pass
        raise
    print("Finished downloading {}.".format(url))
| 29 | |
| 30 | |
def download_test_resources(test_res_entries: List[dict], where_to: str):
    """Fetch every resource listed in ``test_res_entries`` into ``where_to``.

    Each entry must provide ``'model'``, ``'ifm'`` and ``'ofm'`` URLs. They
    are saved as ``model.tflite``, ``model_ifm.npy`` and ``model_ofm.npy``
    respectively. ``where_to`` is created if it does not already exist.

    NOTE: the local filenames are fixed, so when more than one entry is given
    each entry overwrites the files written for the previous one.
    """
    os.makedirs(where_to, exist_ok=True)

    # (entry key, local filename) pairs, in download order.
    targets = (
        ('model', 'model.tflite'),
        ('ifm', 'model_ifm.npy'),
        ('ofm', 'model_ofm.npy'),
    )
    for entry in test_res_entries:
        for key, filename in targets:
            download(os.path.join(where_to, filename), entry[key])
| 38 | |
| 39 | |
def main():
    """Download the test resources into the ``shared`` folder beside this script."""
    shared_dir = Path(__file__).parent.absolute() / 'shared'
    download_test_resources(test_resources, str(shared_dir))


if __name__ == '__main__':
    main()