Jeremy Johnson | be1a940 | 2021-12-15 17:14:56 +0000 | [diff] [blame] | 1 | """Tests for json2numpy.py.""" |
| 2 | # Copyright (c) 2021-2022, ARM Limited. |
| 3 | # SPDX-License-Identifier: Apache-2.0 |
| 4 | import json |
| 5 | import os |
| 6 | |
| 7 | import numpy as np |
| 8 | import pytest |
Jeremy Johnson | be1a940 | 2021-12-15 17:14:56 +0000 | [diff] [blame] | 9 | from json2numpy.json2numpy import main |
| 10 | |
| 11 | |
@pytest.mark.parametrize(
    "npy_filename,json_filename,data_type",
    [
        ("single_num.npy", "single_num.json", np.int8),
        ("multiple_num.npy", "multiple_num.json", np.int8),
        ("single_num.npy", "single_num.json", np.int16),
        ("multiple_num.npy", "multiple_num.json", np.int16),
        ("single_num.npy", "single_num.json", np.int32),
        ("multiple_num.npy", "multiple_num.json", np.int32),
        ("single_num.npy", "single_num.json", np.int64),
        ("multiple_num.npy", "multiple_num.json", np.int64),
        ("single_num.npy", "single_num.json", np.uint8),
        ("multiple_num.npy", "multiple_num.json", np.uint8),
        ("single_num.npy", "single_num.json", np.uint16),
        ("multiple_num.npy", "multiple_num.json", np.uint16),
        ("single_num.npy", "single_num.json", np.uint32),
        ("multiple_num.npy", "multiple_num.json", np.uint32),
        ("single_num.npy", "single_num.json", np.uint64),
        ("multiple_num.npy", "multiple_num.json", np.uint64),
        ("single_num.npy", "single_num.json", np.float16),
        ("multiple_num.npy", "multiple_num.json", np.float16),
        ("single_num.npy", "single_num.json", np.float32),
        ("multiple_num.npy", "multiple_num.json", np.float32),
        ("single_num.npy", "single_num.json", np.float64),
        ("multiple_num.npy", "multiple_num.json", np.float64),
        ("single_num.npy", "single_num.json", bool),
        ("multiple_num.npy", "multiple_num.json", bool),
    ],
)
def test_json2numpy_npy_file(npy_filename, json_filename, data_type):
    """Test conversion of a numpy file to JSON.

    A small array of ``data_type`` is saved as ``npy_filename`` next to this
    test module and passed to ``main``. The test then reads ``json_filename``
    from the same directory and checks that its "type" and "data" entries
    match the dtype, shape and values of the saved array.
    """
    # Pick the array shape from the fixture name.
    if "single" in npy_filename:
        shape = (1, 1)
    elif "multiple" in npy_filename:
        shape = (2, 3)
    else:
        raise ValueError(f"Unexpected test filename: {npy_filename}")

    # Generate deterministic numpy data. np.ndarray(shape=...) would hand
    # back uninitialised memory, which for float types may contain NaN and
    # make the equality check below fail at random (NaN != NaN).
    npy_data = np.arange(shape[0] * shape[1]).reshape(shape).astype(data_type)

    # Get filepaths
    test_dir = os.path.dirname(__file__)
    npy_file = os.path.join(test_dir, npy_filename)
    json_file = os.path.join(test_dir, json_filename)

    try:
        # Save npy data to file and reload it.
        with open(npy_file, "wb") as f:
            np.save(f, npy_data)
        npy_data = np.load(npy_file)

        # Converts npy file to json
        args = [npy_file]
        assert main(args) == 0

        # Close the JSON file deterministically instead of leaking the handle.
        with open(json_file) as f:
            json_data = json.load(f)
        converted = np.array(json_data["data"])

        assert np.dtype(json_data["type"]) == npy_data.dtype
        assert converted.shape == npy_data.shape
        assert (converted == npy_data).all()
    finally:
        # Remove files created, even when an assertion above failed, so a
        # failing case does not leave stale files behind for the next one.
        for created_file in (npy_file, json_file):
            if os.path.exists(created_file):
                os.remove(created_file)
| 72 | |
| 73 | |
@pytest.mark.parametrize(
    "npy_filename,json_filename,data_type",
    [
        ("single_num.npy", "single_num.json", np.int8),
        ("multiple_num.npy", "multiple_num.json", np.int8),
        ("single_num.npy", "single_num.json", np.int16),
        ("multiple_num.npy", "multiple_num.json", np.int16),
        ("single_num.npy", "single_num.json", np.int32),
        ("multiple_num.npy", "multiple_num.json", np.int32),
        ("single_num.npy", "single_num.json", np.int64),
        ("multiple_num.npy", "multiple_num.json", np.int64),
        ("single_num.npy", "single_num.json", np.uint8),
        ("multiple_num.npy", "multiple_num.json", np.uint8),
        ("single_num.npy", "single_num.json", np.uint16),
        ("multiple_num.npy", "multiple_num.json", np.uint16),
        ("single_num.npy", "single_num.json", np.uint32),
        ("multiple_num.npy", "multiple_num.json", np.uint32),
        ("single_num.npy", "single_num.json", np.uint64),
        ("multiple_num.npy", "multiple_num.json", np.uint64),
        ("single_num.npy", "single_num.json", np.float16),
        ("multiple_num.npy", "multiple_num.json", np.float16),
        ("single_num.npy", "single_num.json", np.float32),
        ("multiple_num.npy", "multiple_num.json", np.float32),
        ("single_num.npy", "single_num.json", np.float64),
        ("multiple_num.npy", "multiple_num.json", np.float64),
        ("single_num.npy", "single_num.json", bool),
        ("multiple_num.npy", "multiple_num.json", bool),
    ],
)
def test_json2numpy_json_file(npy_filename, json_filename, data_type):
    """Test conversion of a JSON file to a numpy binary file.

    A dictionary with "type" and "data" entries, built from a small array of
    ``data_type``, is written to ``json_filename`` next to this test module
    and passed to ``main``. The test then loads ``npy_filename`` from the same
    directory and checks its dtype, shape and values against the JSON content.
    """
    # Pick the array shape from the fixture name.
    if "single" in npy_filename:
        shape = (1, 1)
    elif "multiple" in npy_filename:
        shape = (2, 3)
    else:
        raise ValueError(f"Unexpected test filename: {npy_filename}")

    # Generate deterministic source data. np.ndarray(shape=...) would hand
    # back uninitialised memory, which for float types may contain NaN and
    # make the equality check below fail at random (NaN != NaN).
    npy_data = np.arange(shape[0] * shape[1]).reshape(shape).astype(data_type)

    # Generate json dictionary
    json_data = {
        "type": str(npy_data.dtype),
        "data": npy_data.tolist(),
    }

    # Get filepaths
    test_dir = os.path.dirname(__file__)
    npy_file = os.path.join(test_dir, npy_filename)
    json_file = os.path.join(test_dir, json_filename)

    try:
        # Save json data to file and reload it.
        with open(json_file, "w") as f:
            json.dump(json_data, f)
        # Close the JSON file deterministically instead of leaking the handle.
        with open(json_file) as f:
            json_data = json.load(f)

        # Converts json file to npy
        args = [json_file]
        assert main(args) == 0

        npy_data = np.load(npy_file)
        expected = np.array(json_data["data"])

        assert np.dtype(json_data["type"]) == npy_data.dtype
        assert expected.shape == npy_data.shape
        assert (expected == npy_data).all()
    finally:
        # Remove files created, even when an assertion above failed, so a
        # failing case does not leave stale files behind for the next one.
        for created_file in (npy_file, json_file):
            if os.path.exists(created_file):
                os.remove(created_file)