blob: b01ebe93d99ef5bde567e1f53f58364e49dd985e [file] [log] [blame]
Jeremy Johnsonbe1a9402021-12-15 17:14:56 +00001"""Tests for json2numpy.py."""
Jeremy Johnson898d3a22023-07-26 13:26:19 +01002# Copyright (c) 2021-2023, ARM Limited.
Jeremy Johnsonbe1a9402021-12-15 17:14:56 +00003# SPDX-License-Identifier: Apache-2.0
Jeremy Johnsonbe1a9402021-12-15 17:14:56 +00004import os
5
6import numpy as np
7import pytest
Jeremy Johnsonbe1a9402021-12-15 17:14:56 +00008from json2numpy.json2numpy import main
9
10
# Bounds for random integer generation, keyed by numpy type name:
# [inclusive minimum, exclusive maximum] covering the full value range of
# each integer dtype (suitable as ``low``/``high`` for ``rng.integers``).
DTYPE_RANGES = {
    int_type.__name__: [
        int(np.iinfo(int_type).min),
        int(np.iinfo(int_type).max) + 1,
    ]
    for int_type in (
        np.int8,
        np.uint8,
        np.int16,
        np.uint16,
        np.int32,
        np.uint32,
        np.int64,
        np.uint64,
    )
}
21
22
@pytest.mark.parametrize(
    "npy_filename,json_filename,data_type",
    [
        ("single_num.npy", "single_num.json", np.int8),
        ("multiple_num.npy", "multiple_num.json", np.int8),
        ("single_num.npy", "single_num.json", np.int16),
        ("multiple_num.npy", "multiple_num.json", np.int16),
        ("single_num.npy", "single_num.json", np.int32),
        ("multiple_num.npy", "multiple_num.json", np.int32),
        ("single_num.npy", "single_num.json", np.int64),
        ("multiple_num.npy", "multiple_num.json", np.int64),
        ("single_num.npy", "single_num.json", np.uint8),
        ("multiple_num.npy", "multiple_num.json", np.uint8),
        ("single_num.npy", "single_num.json", np.uint16),
        ("multiple_num.npy", "multiple_num.json", np.uint16),
        ("single_num.npy", "single_num.json", np.uint32),
        ("multiple_num.npy", "multiple_num.json", np.uint32),
        # Not implemented due to json.dump issue
        # ("single_num.npy", "single_num.json", np.uint64),
        # ("multiple_num.npy", "multiple_num.json", np.uint64),
        ("single_num.npy", "single_num.json", np.float16),
        ("multiple_num.npy", "multiple_num.json", np.float16),
        ("single_num.npy", "single_num.json", np.float32),
        ("multiple_num.npy", "multiple_num.json", np.float32),
        ("single_num.npy", "single_num.json", np.float64),
        ("multiple_num.npy", "multiple_num.json", np.float64),
        ("single_num.npy", "single_num.json", bool),
        ("multiple_num.npy", "multiple_num.json", bool),
    ],
)
def test_json2numpy_there_and_back(npy_filename, json_filename, data_type):
    """Round-trip random data through json2numpy: npy -> json -> npy.

    Random data of ``data_type`` is saved to ``npy_filename``. It is converted
    to JSON with ``main`` and the npy file is deleted. The JSON
    (``json_filename``) is then converted back with ``main``. The reloaded
    array must match the original in dtype, shape and values. NaNs in the
    same positions count as equal.

    The files live next to this test module. They are always removed, even
    when an assertion fails, so a failing run does not litter the source tree.
    """
    # Choose the data shape from the test filename.
    if "single" in npy_filename:
        shape = (1,)
    elif "multiple" in npy_filename:
        shape = (4, 6, 5)
    else:
        pytest.fail(f"Unrecognised test filename: {npy_filename}")

    rng = np.random.default_rng()
    is_float = data_type in (np.float16, np.float32, np.float64)
    if is_float:
        # standard_normal only accepts float32/float64, so float16 data is
        # generated as float32 and then down-cast.
        gen_type = np.float32 if data_type is np.float16 else data_type
        generated_npy_data = rng.standard_normal(size=shape, dtype=gen_type).astype(
            data_type
        )
        if len(shape) > 1:
            # Set some NaNs and INFs to check special values survive the trip.
            generated_npy_data[(1, 2, 3)] = np.nan
            generated_npy_data[(3, 2, 1)] = np.inf
            generated_npy_data[(0, 5, 2)] = -np.inf
    elif data_type is bool:
        generated_npy_data = rng.choice([True, False], size=shape).astype(bool)
    else:
        # Cover the full value range of the integer type (high is exclusive).
        low, high = DTYPE_RANGES[data_type.__name__]
        generated_npy_data = rng.integers(
            low=low, high=high, size=shape, dtype=data_type
        )

    # Get filepaths
    test_dir = os.path.dirname(__file__)
    npy_file = os.path.join(test_dir, npy_filename)
    json_file = os.path.join(test_dir, json_filename)

    try:
        # Save npy data to file and reload it.
        with open(npy_file, "wb") as f:
            np.save(f, generated_npy_data)
        npy_data = np.load(npy_file)

        # Test json2numpy - converts npy file to json
        assert main([npy_file]) == 0

        # Remove the numpy file so the reverse conversion must recreate it.
        os.remove(npy_file)
        assert not os.path.exists(npy_file)
        assert main([json_file]) == 0

        converted_npy_data = np.load(npy_file)

        # Check that the original data equals the npy->json->npy data
        assert converted_npy_data.dtype == npy_data.dtype
        assert converted_npy_data.shape == npy_data.shape
        equals = np.equal(converted_npy_data, npy_data)
        if is_float:
            # NaN never compares equal to itself; treat NaNs in matching
            # positions of both arrays as equal.
            equals |= np.isnan(converted_npy_data) & np.isnan(npy_data)
        if not np.all(equals):
            print("JSONed: ", converted_npy_data)
            print("Original:", npy_data)
            print("Equals: ", equals)
        assert np.all(equals)
    finally:
        # Remove files created, whether or not the checks passed.
        for path in (npy_file, json_file):
            if os.path.exists(path):
                os.remove(path)