# blob: 18aa9afd8d395105e09f9eea7b35e9455ccbaab8 (last change: Kshitij Sisodia, f9efe0d, 2022-09-30 16:42:50 +0100)
#
# SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
import os
import shutil
from pathlib import Path
from typing import List
from urllib.request import urlopen
"""
Downloads resources for tests from the Arm public model zoo.
Run this script before executing the tests.
"""
13
14
# Base URL of the "models" tree in Arm's ML-zoo repository, referenced at a
# fixed commit hash rather than a branch name.
PMZ_URL = 'https://github.com/ARM-software/ML-zoo/raw/9f506fe52b39df545f0e6c5ff9223f671bc5ae00/models'

# One dict per test case: URLs of the .tflite model ('model') and of the
# .npy testing input ('ifm') and testing output ('ofm') that accompany it.
test_resources = [
    {
        'model': PMZ_URL + '/visual_wake_words/micronet_vww2/tflite_int8/vww2_50_50_INT8.tflite',
        'ifm': PMZ_URL + '/visual_wake_words/micronet_vww2/tflite_int8/testing_input/input/0.npy',
        'ofm': PMZ_URL + '/visual_wake_words/micronet_vww2/tflite_int8/testing_output/Identity/0.npy',
    },
]
21
22
def download(path: str, url: str):
    """Download ``url`` and save the response body to ``path``.

    The body is streamed to a temporary ``<path>.part`` file, which is renamed
    onto ``path`` only after the whole transfer has completed. An interrupted
    or failed download therefore never leaves a truncated file at ``path``
    (any existing file there is left untouched in that case).

    :param path: destination file path; overwritten on success.
    :param url: URL to fetch (anything ``urllib.request.urlopen`` accepts).
    :raises OSError: (including ``urllib.error.URLError``) if the request or
        a file operation fails; the temporary file is removed first.
    """
    tmp_path = path + '.part'
    print("Downloading {} ...".format(url))
    try:
        with urlopen(url) as response, open(tmp_path, 'wb') as file:
            # Stream in chunks instead of buffering the whole body in memory.
            shutil.copyfileobj(response, file)
        # Atomic on the same filesystem: readers see either the old file or
        # the complete new one, never a partial write.
        os.replace(tmp_path, path)
    except BaseException:
        # Clean up the half-written temp file, then propagate the error.
        try:
            os.remove(tmp_path)
        except FileNotFoundError:
            pass
        raise
    print("Finished downloading {}.".format(url))
29
30
def download_test_resources(test_res_entries: List[dict], where_to: str):
    """Download every resource described in ``test_res_entries`` into ``where_to``.

    ``where_to`` is created (including parents) if it does not exist yet.
    Each entry must provide the URL keys ``'model'``, ``'ifm'`` and ``'ofm'``.
    They are saved as ``model.tflite``, ``model_ifm.npy`` and
    ``model_ofm.npy`` respectively. Because these file names are fixed, every
    entry overwrites the files written for the previous one, so only the last
    entry's resources remain on disk.
    """
    os.makedirs(where_to, exist_ok=True)

    # (resource key, local file name), in the order they are fetched.
    targets = (
        ('model', 'model.tflite'),
        ('ifm', 'model_ifm.npy'),
        ('ofm', 'model_ofm.npy'),
    )
    for entry in test_res_entries:
        for key, file_name in targets:
            download(os.path.join(where_to, file_name), entry[key])
38
39
def main():
    """Fetch all ``test_resources`` into the ``shared`` directory next to this script."""
    shared_dir = Path(__file__).parent.absolute() / 'shared'
    download_test_resources(test_resources, str(shared_dir))


if __name__ == '__main__':
    main()