Jeremy Johnson | be1a940 | 2021-12-15 17:14:56 +0000 | [diff] [blame] | 1 | """Tests for json2numpy.py.""" |
Jeremy Johnson | 898d3a2 | 2023-07-26 13:26:19 +0100 | [diff] [blame] | 2 | # Copyright (c) 2021-2023, ARM Limited. |
Jeremy Johnson | be1a940 | 2021-12-15 17:14:56 +0000 | [diff] [blame] | 3 | # SPDX-License-Identifier: Apache-2.0 |
Jeremy Johnson | be1a940 | 2021-12-15 17:14:56 +0000 | [diff] [blame] | 4 | import os |
| 5 | |
| 6 | import numpy as np |
| 7 | import pytest |
Jeremy Johnson | be1a940 | 2021-12-15 17:14:56 +0000 | [diff] [blame] | 8 | from json2numpy.json2numpy import main |
| 9 | |
| 10 | |
# Half-open [low, high) value ranges for each integer dtype, keyed by the
# dtype's ``__name__``. Derived from ``np.iinfo`` so the bounds cannot drift
# from the real type limits; the upper bound is ``max + 1`` because it is used
# as the exclusive ``high`` argument of ``Generator.integers``.
DTYPE_RANGES = {
    int_type.__name__: [
        int(np.iinfo(int_type).min),
        int(np.iinfo(int_type).max) + 1,
    ]
    for int_type in (
        np.int8,
        np.uint8,
        np.int16,
        np.uint16,
        np.int32,
        np.uint32,
        np.int64,
        np.uint64,
    )
}
| 21 | |
| 22 | |
def _shape_for(npy_filename):
    """Return the array shape implied by a parametrized npy file name."""
    if "single" in npy_filename:
        return (1,)
    if "multiple" in npy_filename:
        return (4, 6, 5)
    raise ValueError(f"Cannot infer a test shape from file name: {npy_filename}")


def _generate_npy_data(shape, data_type, rng):
    """Return random data of the given shape and dtype.

    Multi-dimensional float arrays additionally get a NaN, +inf and -inf at
    fixed positions so special values are exercised by the round trip.
    """
    if data_type in (np.float16, np.float32, np.float64):
        # Generator.standard_normal only supports float32/float64, so float16
        # data is generated as float32 and cast down.
        gen_type = np.float32 if data_type == np.float16 else data_type
        data = rng.standard_normal(size=shape, dtype=gen_type).astype(data_type)
        if len(shape) > 1:
            data[(1, 2, 3)] = np.nan
            data[(3, 2, 1)] = np.inf
            data[(0, 5, 2)] = -np.inf
        return data
    if data_type == bool:
        return rng.choice([True, False], size=shape).astype(bool)
    # DTYPE_RANGES holds half-open [low, high) bounds, matching the exclusive
    # ``high`` of Generator.integers.
    low, high = DTYPE_RANGES[data_type.__name__]
    return rng.integers(low=low, high=high, size=shape, dtype=data_type)


@pytest.mark.parametrize(
    "npy_filename,json_filename,data_type",
    [
        ("single_num.npy", "single_num.json", np.int8),
        ("multiple_num.npy", "multiple_num.json", np.int8),
        ("single_num.npy", "single_num.json", np.int16),
        ("multiple_num.npy", "multiple_num.json", np.int16),
        ("single_num.npy", "single_num.json", np.int32),
        ("multiple_num.npy", "multiple_num.json", np.int32),
        ("single_num.npy", "single_num.json", np.int64),
        ("multiple_num.npy", "multiple_num.json", np.int64),
        ("single_num.npy", "single_num.json", np.uint8),
        ("multiple_num.npy", "multiple_num.json", np.uint8),
        ("single_num.npy", "single_num.json", np.uint16),
        ("multiple_num.npy", "multiple_num.json", np.uint16),
        ("single_num.npy", "single_num.json", np.uint32),
        ("multiple_num.npy", "multiple_num.json", np.uint32),
        # Not implemented due to json.dump issue
        # ("single_num.npy", "single_num.json", np.uint64),
        # ("multiple_num.npy", "multiple_num.json", np.uint64),
        ("single_num.npy", "single_num.json", np.float16),
        ("multiple_num.npy", "multiple_num.json", np.float16),
        ("single_num.npy", "single_num.json", np.float32),
        ("multiple_num.npy", "multiple_num.json", np.float32),
        ("single_num.npy", "single_num.json", np.float64),
        ("multiple_num.npy", "multiple_num.json", np.float64),
        ("single_num.npy", "single_num.json", bool),
        ("multiple_num.npy", "multiple_num.json", bool),
    ],
)
def test_json2numpy_there_and_back(npy_filename, json_filename, data_type, tmp_path):
    """Test that npy -> JSON -> npy conversion preserves dtype, shape and values.

    All files live in pytest's per-test ``tmp_path`` directory. A failing
    assertion therefore cannot leave stale files next to the test sources, and
    parametrized cases that share file names cannot collide when run in
    parallel.
    """
    shape = _shape_for(npy_filename)
    generated_npy_data = _generate_npy_data(shape, data_type, np.random.default_rng())

    # Get filepaths (json2numpy writes its output alongside its input).
    npy_file = os.path.join(tmp_path, npy_filename)
    json_file = os.path.join(tmp_path, json_filename)

    # Save npy data to file and reload it.
    with open(npy_file, "wb") as f:
        np.save(f, generated_npy_data)
    npy_data = np.load(npy_file)

    # Test json2numpy - converts npy file to json
    assert main([npy_file]) == 0

    # Remove the numpy file and convert json back to npy
    os.remove(npy_file)
    assert not os.path.exists(npy_file)
    assert main([json_file]) == 0

    converted_npy_data = np.load(npy_file)

    # Check that the original data equals the npy->json->npy data
    assert converted_npy_data.dtype == npy_data.dtype
    assert converted_npy_data.shape == npy_data.shape
    equals = np.equal(converted_npy_data, npy_data)
    if np.issubdtype(npy_data.dtype, np.floating):
        # NaN never compares equal to itself, so count NaN-vs-NaN as a match
        # wherever it occurs in the array.
        equals |= np.isnan(converted_npy_data) & np.isnan(npy_data)
    if not np.all(equals):
        print("JSONed: ", converted_npy_data)
        print("Original:", npy_data)
        print("Equals: ", equals)
    assert np.all(equals)