blob: b166ed77be2c45c0d9065a8b86dac7907902b7f1 [file] [log] [blame]
"""Downloads and extracts resources for unit tests.

This script must be run before running the unit tests. The resources are stored as a tar.gz or a tar.bz2 archive and
are extracted into the test/testdata/shared folder.
"""
6
7import tarfile
8import requests
9import os
10import uuid
11
# Directory containing this script, with symlinks resolved.
SCRIPTS_DIR = os.path.split(os.path.realpath(__file__))[0]
# Resources are unpacked into the "test" directory that sits next to the
# scripts directory.
EXTRACT_DIR = os.path.join(SCRIPTS_DIR, os.pardir, "test")
# Fixed location of the PyArmNN test-data archive (a .tar.bz2 file).
ARCHIVE_URL = (
    "https://snapshots.linaro.org/components/pyarmnn-tests/"
    "pyarmnn_testdata_200500_20200415.tar.bz2"
)
15
16
def download_resources(url, save_path, timeout=60):
    """Download a tar.gz or tar.bz2 archive and extract it into a directory.

    The archive is streamed to a uniquely named temporary file inside
    ``save_path`` and extracted there. The temporary file is always removed
    afterwards, even if the download or the extraction fails.

    Args:
        url: Location of the archive. Must end with ``.tar.bz2`` or
            ``.tar.gz``; the suffix selects the decompression mode.
        save_path: Directory that receives both the temporary archive and
            the extracted contents.
        timeout: Seconds to wait for the server to connect or to send data
            before giving up (passed to ``requests.get``).

    Raises:
        RuntimeError: If the URL has an unsupported suffix, or if the
            download fails (network error, timeout or non-2xx HTTP status).
    """
    print("Downloading '{}'".format(url))
    # Only tar.bz2 and tar.gz are supported; remember the matching read mode.
    if url.endswith(".tar.bz2"):
        suffix, mode = ".tar.bz2", "r:bz2"
    elif url.endswith(".tar.gz"):
        suffix, mode = ".tar.gz", "r:gz"
    else:
        raise RuntimeError("Unsupported file.")

    # A random name avoids clashing with anything already in save_path.
    file_path = os.path.join(save_path, str(uuid.uuid4()) + suffix)
    try:
        try:
            # stream=True together with iter_content writes the body in chunks
            # instead of buffering the whole archive in memory; the with-block
            # releases the connection when done.
            with requests.get(url, stream=True, timeout=timeout) as r:
                # Without this, an HTTP error page would be saved to disk and
                # only fail later inside tarfile with a confusing message.
                r.raise_for_status()
                with open(file_path, "wb") as f:
                    for chunk in r.iter_content(chunk_size=1024 * 1024):
                        f.write(chunk)
        except requests.exceptions.RequestException as e:
            raise RuntimeError("Unable to download file: {}".format(e)) from e

        with tarfile.open(file_path, mode) as tar:
            print("Extracting '{}'".format(file_path))
            # The archive comes from the network. Where extraction filters are
            # available (Python 3.12+ and security backports), the "data"
            # filter refuses members that would land outside save_path and
            # special files such as devices. Python 3.14 applies it by default.
            extract_kwargs = (
                {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
            )
            tar.extractall(save_path, **extract_kwargs)
    finally:
        # Remove the temporary archive whatever happened above.
        if os.path.exists(file_path):
            print("Removing '{}'".format(file_path))
            os.remove(file_path)
42
43
# Executed whenever this module runs: fetch ARCHIVE_URL and unpack it into
# EXTRACT_DIR.
download_resources(url=ARCHIVE_URL, save_path=EXTRACT_DIR)