blob: b2c58b0474a47ef29a414056dbf8bda1e4ffdf7a [file] [log] [blame]
Pavel Macenauer59e057f2020-04-15 14:17:26 +00001#!/usr/bin/env python3
2# Copyright 2020 NXP
3# SPDX-License-Identifier: MIT
Pavel Macenauer5e123f82020-04-15 13:28:29 +00004"""Downloads and extracts resources for unit tests.
5
6It is mandatory to run this script prior to running unit tests. Resources are stored as a tar.gz or a tar.bz2 archive and
7extracted into the test/testdata/shared folder.
8"""
9
10import tarfile
11import requests
12import os
13import uuid
14
# Absolute directory containing this script, with any symlinks resolved.
SCRIPTS_DIR = os.path.split(os.path.realpath(__file__))[0]
# Extraction target: the "test" directory that sits next to SCRIPTS_DIR.
EXTRACT_DIR = os.path.join(SCRIPTS_DIR, os.pardir, "test")
# Versioned snapshot of the pyarmnn unit-test data to download.
ARCHIVE_URL = "https://snapshots.linaro.org/components/pyarmnn-tests/pyarmnn_testdata_201100_20201022.tar.bz2"
Pavel Macenauer5e123f82020-04-15 13:28:29 +000018
19
def download_resources(url, save_path, timeout=60):
    """Download a tar archive from *url* and extract it into *save_path*.

    Only ``.tar.gz`` and ``.tar.bz2`` archives are supported. The archive is
    streamed to a uniquely named temporary file inside *save_path*, extracted
    into *save_path*, and the temporary file is removed afterwards -- also
    when the download or the extraction fails.

    Args:
        url: Address of the archive; must end in ``.tar.gz`` or ``.tar.bz2``.
        save_path: Directory used both for the temporary download and as the
            extraction destination.
        timeout: Seconds ``requests`` waits to connect / between received
            bytes before giving up.

    Raises:
        RuntimeError: If the URL has an unsupported extension, or the
            download fails (network error or non-2xx HTTP status).
    """
    print("Downloading '{}'".format(url))
    # Pick the temp-file suffix and the matching tarfile mode together.
    if url.endswith(".tar.bz2"):
        suffix, mode = ".tar.bz2", "r:bz2"
    elif url.endswith(".tar.gz"):
        suffix, mode = ".tar.gz", "r:gz"
    else:
        raise RuntimeError("Unsupported file.")
    file_path = os.path.join(save_path, str(uuid.uuid4()) + suffix)

    try:
        try:
            with requests.get(url, stream=True, timeout=timeout) as r:
                # Fail fast on HTTP errors instead of saving an error page
                # and tripping over it later in tarfile.
                r.raise_for_status()
                with open(file_path, 'wb') as f:
                    # Write in chunks so the archive is never held fully in memory.
                    for chunk in r.iter_content(chunk_size=1 << 20):
                        f.write(chunk)
        except requests.exceptions.RequestException as e:
            raise RuntimeError("Unable to download file: {}".format(e)) from e

        with tarfile.open(file_path, mode) as tar:
            print("Extracting '{}'".format(file_path))
            # The archive comes from the network: use the "data" extraction
            # filter (rejects absolute paths / escaping links) when available.
            if hasattr(tarfile, "data_filter"):
                tar.extractall(save_path, filter="data")
            else:
                tar.extractall(save_path)
    finally:
        # Always clean up the temporary archive, even on failure.
        if os.path.exists(file_path):
            print("Removing '{}'".format(file_path))
            os.remove(file_path)
45
46
# Script entry: fetch the pinned test-data snapshot and unpack it under test/.
download_resources(url=ARCHIVE_URL, save_path=EXTRACT_DIR)