blob: 63bc2d99c3ea5f2f755a2ad2edb84065c8f3bcfb [file] [log] [blame]
Jeremy Johnsonbe1a9402021-12-15 17:14:56 +00001"""Tests for json2numpy.py."""
2# Copyright (c) 2021-2022, ARM Limited.
3# SPDX-License-Identifier: Apache-2.0
4import json
5import os
6
7import numpy as np
8import pytest
Jeremy Johnsonbe1a9402021-12-15 17:14:56 +00009from json2numpy.json2numpy import main
10
11
@pytest.mark.parametrize(
    "npy_filename,json_filename,data_type",
    [
        ("single_num.npy", "single_num.json", np.int8),
        ("multiple_num.npy", "multiple_num.json", np.int8),
        ("single_num.npy", "single_num.json", np.int16),
        ("multiple_num.npy", "multiple_num.json", np.int16),
        ("single_num.npy", "single_num.json", np.int32),
        ("multiple_num.npy", "multiple_num.json", np.int32),
        ("single_num.npy", "single_num.json", np.int64),
        ("multiple_num.npy", "multiple_num.json", np.int64),
        ("single_num.npy", "single_num.json", np.uint8),
        ("multiple_num.npy", "multiple_num.json", np.uint8),
        ("single_num.npy", "single_num.json", np.uint16),
        ("multiple_num.npy", "multiple_num.json", np.uint16),
        ("single_num.npy", "single_num.json", np.uint32),
        ("multiple_num.npy", "multiple_num.json", np.uint32),
        ("single_num.npy", "single_num.json", np.uint64),
        ("multiple_num.npy", "multiple_num.json", np.uint64),
        ("single_num.npy", "single_num.json", np.float16),
        ("multiple_num.npy", "multiple_num.json", np.float16),
        ("single_num.npy", "single_num.json", np.float32),
        ("multiple_num.npy", "multiple_num.json", np.float32),
        ("single_num.npy", "single_num.json", np.float64),
        ("multiple_num.npy", "multiple_num.json", np.float64),
        ("single_num.npy", "single_num.json", bool),
        ("multiple_num.npy", "multiple_num.json", bool),
    ],
)
def test_json2numpy_npy_file(npy_filename, json_filename, data_type):
    """Test conversion to JSON.

    Saves a small array of ``data_type`` to ``npy_filename`` next to this test
    file, runs ``main`` on it, and checks that ``json_filename`` contains the
    same dtype name, shape and values.
    """
    # Pick the array shape from the parametrized file name.
    if "single" in npy_filename:
        shape = (1, 1)
    elif "multiple" in npy_filename:
        shape = (2, 3)
    else:
        pytest.fail(f"Unexpected test file name: {npy_filename}")

    # Use deterministic contents. A bare np.ndarray(...) is uninitialized
    # memory: for float dtypes it may hold NaN (NaN != NaN), which made the
    # equality check below flaky.
    npy_data = np.arange(int(np.prod(shape))).reshape(shape).astype(data_type)

    # Get filepaths
    test_dir = os.path.dirname(__file__)
    npy_file = os.path.join(test_dir, npy_filename)
    json_file = os.path.join(test_dir, json_filename)

    try:
        # Save npy data to file and reload it.
        with open(npy_file, "wb") as f:
            np.save(f, npy_data)
        npy_data = np.load(npy_file)

        # Convert the npy file to json.
        args = [npy_file]
        assert main(args) == 0

        with open(json_file) as f:
            json_data = json.load(f)
        json_array = np.array(json_data["data"])
        assert np.dtype(json_data["type"]) == npy_data.dtype
        assert json_array.shape == npy_data.shape
        assert (json_array == npy_data).all()
    finally:
        # Remove files created, even if an assertion above failed.
        for path in (npy_file, json_file):
            if os.path.exists(path):
                os.remove(path)
72
73
74@pytest.mark.parametrize(
75 "npy_filename,json_filename,data_type",
76 [
77 ("single_num.npy", "single_num.json", np.int8),
78 ("multiple_num.npy", "multiple_num.json", np.int8),
79 ("single_num.npy", "single_num.json", np.int16),
80 ("multiple_num.npy", "multiple_num.json", np.int16),
81 ("single_num.npy", "single_num.json", np.int32),
82 ("multiple_num.npy", "multiple_num.json", np.int32),
83 ("single_num.npy", "single_num.json", np.int64),
84 ("multiple_num.npy", "multiple_num.json", np.int64),
85 ("single_num.npy", "single_num.json", np.uint8),
86 ("multiple_num.npy", "multiple_num.json", np.uint8),
87 ("single_num.npy", "single_num.json", np.uint16),
88 ("multiple_num.npy", "multiple_num.json", np.uint16),
89 ("single_num.npy", "single_num.json", np.uint32),
90 ("multiple_num.npy", "multiple_num.json", np.uint32),
91 ("single_num.npy", "single_num.json", np.uint64),
92 ("multiple_num.npy", "multiple_num.json", np.uint64),
93 ("single_num.npy", "single_num.json", np.float16),
94 ("multiple_num.npy", "multiple_num.json", np.float16),
95 ("single_num.npy", "single_num.json", np.float32),
96 ("multiple_num.npy", "multiple_num.json", np.float32),
97 ("single_num.npy", "single_num.json", np.float64),
98 ("multiple_num.npy", "multiple_num.json", np.float64),
99 ("single_num.npy", "single_num.json", bool),
100 ("multiple_num.npy", "multiple_num.json", bool),
101 ],
102)
103def test_json2numpy_json_file(npy_filename, json_filename, data_type):
104 """Test conversion to binary."""
105 # Generate json data.
106 if "single" in npy_filename:
107 npy_data = np.ndarray(shape=(1, 1), dtype=data_type)
108 elif "multiple" in npy_filename:
109 npy_data = np.ndarray(shape=(2, 3), dtype=data_type)
110
111 # Generate json dictionary
112 list_data = npy_data.tolist()
113 json_data_type = str(npy_data.dtype)
114
115 json_data = {}
116 json_data["type"] = json_data_type
117 json_data["data"] = list_data
118
119 # Get filepaths
120 npy_file = os.path.join(os.path.dirname(__file__), npy_filename)
121 json_file = os.path.join(os.path.dirname(__file__), json_filename)
122
123 # Save json data to file and reload it.
124 with open(json_file, "w") as f:
125 json.dump(json_data, f)
126 json_data = json.load(open(json_file))
127
128 args = [json_file]
129 """Converts json file to npy"""
130 assert main(args) == 0
131
132 npy_data = np.load(npy_file)
133 assert np.dtype(json_data["type"]) == npy_data.dtype
134 assert np.array(json_data["data"]).shape == npy_data.shape
135 assert (np.array(json_data["data"]) == npy_data).all()
136
137 # Remove files created
138 if os.path.exists(npy_file):
139 os.remove(npy_file)
140 if os.path.exists(json_file):
141 os.remove(json_file)