blob: c04013e8afb16c7425def89ecad914b3a7c7cfcc [file] [log] [blame]
Jeremy Johnsonbe1a9402021-12-15 17:14:56 +00001"""Conversion utility from binary numpy files to JSON and the reverse."""
2# Copyright (c) 2021-2022, ARM Limited.
3# SPDX-License-Identifier: Apache-2.0
4import json
5from pathlib import Path
6from typing import Optional
7from typing import Union
8
9import numpy as np
10
11
class NumpyArrayEncoder(json.JSONEncoder):
    """A JSON encoder for Numpy data types.

    Converts NumPy scalars and arrays into equivalent built-in Python objects
    so that ``json.dump``/``json.dumps`` can serialize them.  Use it via
    ``cls=NumpyArrayEncoder``.
    """

    def default(self, obj):
        """Encode objects that the base JSONEncoder cannot handle.

        - NumPy integer scalars become ``int``.
        - NumPy floating scalars (float16, float32, ...) become ``float``.
        - NumPy boolean scalars become ``bool``.
        - ``np.ndarray`` becomes (nested) lists via ``tolist()``.

        Anything else is deferred to the base class, which raises TypeError.
        """
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            # float16 used to be returned as np.float16, which is still not
            # JSON-serializable, so encoding it failed; convert every NumPy
            # float to a plain Python float instead.
            return float(obj)
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)
26
27
def get_shape(t: Union[list, tuple]):
    """Get the shape of an N-Dimensional tensor given as nested lists/tuples.

    The length of each nesting level is taken from the first element at that
    level.  A non-sequence (scalar) has shape ``[]``; an empty sequence is a
    dimension of size 0.

    t: a (possibly nested) list or tuple, or a scalar
    Returns: the shape as a list of ints
    """
    # TODO: validate shape is consistent for all rows and columns
    if not isinstance(t, (list, tuple)):
        return []
    if not t:
        # An empty sequence is still a dimension, just of length 0; returning
        # [] here would silently drop it (e.g. [[], []] -> [2] not [2, 0]).
        return [0]
    return [len(t)] + get_shape(t[0])
34
35
def npy_to_json(n_path: Path, j_path: Optional[Path] = None):
    """Convert a binary NumPy ``.npy`` file into a JSON file.

    The JSON document is an object with two keys: ``"type"`` holds the
    array's dtype name and ``"data"`` holds the array contents as produced by
    ``ndarray.tolist()``.

    n_path: Path of the ``.npy`` file to read
    j_path: Path of the JSON file to write; when falsy, it is placed next to
        n_path with the same stem and a ``.json`` extension
    """
    target = j_path if j_path else n_path.parent / f"{n_path.stem}.json"

    with open(n_path, "rb") as npy_file:
        array = np.load(npy_file)

    payload = {"type": array.dtype.name, "data": array.tolist()}

    with open(target, "w") as json_file:
        json.dump(payload, json_file, indent=2)
52
53
def json_to_npy(j_path: Path, n_path: Optional[Path] = None):
    """Load a JSON file and save it as a numpy data file.

    The JSON file is expected to contain an object with a ``"type"`` key (a
    NumPy dtype name such as ``"int8"``) and a ``"data"`` key holding the
    values as nested lists.

    j_path: the Path to the JSON file
    n_path: the Path to the numpy file; if falsy, it is placed next to
        j_path with the same stem and a ``.npy`` extension
    """
    if not n_path:
        n_path = j_path.parent / (j_path.stem + ".npy")
    with open(j_path, "rb") as fd:
        jdata = json.load(fd)
    raw_data = jdata["data"]
    raw_type = jdata["type"]
    # np.asarray already infers the full N-D shape from the nested lists, so
    # no explicit reshape is needed.  Reshaping to a shape computed from the
    # first element of each level broke arrays with a zero-length dimension
    # (e.g. [] could not be reshaped to a 0-d shape).
    data = np.asarray(raw_data).astype(raw_type)
    with open(n_path, "wb") as fd:
        np.save(fd, data)
70
71
72# ------------------------------------------------------------------------------
73
74
def test():
    """Round-trip check of the npy <-> JSON conversion routines.

    For each dtype below, a 2x3x4 tensor is saved to ``data.npy`` in the
    current directory and reloaded.  It is then converted to ``data.json`` and
    back to ``data_j2n.npy``.  Every stage must reproduce the original shape,
    dtype and values.  The three files are deleted only when all checks pass,
    so a failing run leaves them behind for debugging.

    Returns 0 on success; a failed check raises AssertionError.
    """
    shape = [2, 3, 4]
    element_count = int(np.prod(shape))

    npy_file = Path("data.npy")
    json_file = Path("data.json")
    roundtrip_file = Path("data_j2n.npy")

    # Extended types such as float128, the complex types, datetime64 and str
    # are not exercised by this round trip.
    dtypes_under_test = (
        np.bool_,
        np.int8,
        np.int16,
        np.int32,
        np.int64,
        np.uint8,
        np.uint16,
        np.uint32,
        np.uint64,
        np.float16,
        np.float32,
        np.float64,
    )

    def _check_same(expected, actual):
        # Shape, dtype and every element must agree.
        assert expected.shape == actual.shape
        assert expected.dtype == actual.dtype
        assert (expected == actual).all()

    def _load(path):
        with open(path, "rb") as handle:
            return np.load(handle)

    for kind in dtypes_under_test:
        descr = np.dtype(kind)
        print(kind, descr, descr.char, descr.num, descr.name, descr.str)

        reference = np.arange(element_count).reshape(shape).astype(kind)

        with open(npy_file, "wb") as handle:
            np.save(handle, reference)
        _check_same(reference, _load(npy_file))

        npy_to_json(npy_file, json_file)
        json_to_npy(json_file, roundtrip_file)
        _check_same(reference, _load(roundtrip_file))

    # Only reached when every check passed; on failure the files are kept.
    for leftover in (npy_file, json_file, roundtrip_file):
        leftover.unlink()
    return 0
150
151
def main(argv=None):
    """Load and convert supplied file based on file suffix.

    A ``.npy`` file is converted to JSON and a ``.json`` file to ``.npy``,
    writing the result next to the input.  The special argument ``test`` runs
    the built-in round-trip self test instead.

    argv: argument list excluding the program name; None lets argparse use
        the process command line
    Returns: exit status - 0 on success, 2 for a missing/non-file path or an
        unrecognised suffix; for ``test`` the self test's return value
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "path", type=Path, help="the path to the file to convert, or 'test'"
    )
    args = parser.parse_args(argv)
    path = args.path
    if str(path) == "test":
        print("test")
        return test()

    if not path.is_file():
        print(f"Invalid file - {path}")
        return 2

    if path.suffix == ".npy":
        npy_to_json(path)
    elif path.suffix == ".json":
        json_to_npy(path)
    else:
        # This message was missing its f prefix, so the literal text
        # "{path.suffix}" was printed instead of the actual suffix.
        print(f"Unknown file type - {path.suffix}")
        return 2

    return 0
179
180
if __name__ == "__main__":
    # Use SystemExit rather than the site-provided exit(), which is meant for
    # the interactive shell and is absent when Python runs with -S.
    raise SystemExit(main())