blob: 24e3aff25b0cbb56decd15f834fd240c5a35116a [file] [log] [blame]
Eric Kunzecc426df2024-01-03 00:27:59 +00001// Copyright (c) 2021,2024, ARM Limited.
Eric Kunze2364dcd2021-04-26 11:06:57 -07002//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include <getopt.h>
16#include <iostream>
17#include <random>
18#include <sstream>
19#include <tosa_serialization_handler.h>
20
21using namespace tosa;
22
/// Prints the command-line synopsis of this test program to standard output.
void usage()
{
    static const char* const kSynopsis =
        "Usage: serialization_npy_test -f <filename> -t <shape> -d <datatype> -s <seed>";
    std::cout << kSynopsis << std::endl;
}
27
28template <class T>
29int test_int_type(std::vector<int32_t> shape, std::default_random_engine& gen, std::string& filename)
30{
31 size_t total_size = 1;
32 std::uniform_int_distribution<T> gen_data(std::numeric_limits<T>::min(), std::numeric_limits<T>::max());
33
34 for (auto i : shape)
35 {
36 total_size *= i;
37 }
38
39 auto buffer = std::make_unique<T[]>(total_size);
Eric Kunzecc426df2024-01-03 00:27:59 +000040 for (size_t i = 0; i < total_size; i++)
Eric Kunze2364dcd2021-04-26 11:06:57 -070041 {
42 buffer[i] = gen_data(gen);
43 }
44
45 NumpyUtilities::NPError err = NumpyUtilities::writeToNpyFile(filename.c_str(), shape, buffer.get());
46 if (err != NumpyUtilities::NO_ERROR)
47 {
48 std::cout << "Error writing file, code " << err << std::endl;
49 return 1;
50 }
51
52 auto read_buffer = std::make_unique<T[]>(total_size);
53 err = NumpyUtilities::readFromNpyFile(filename.c_str(), total_size, read_buffer.get());
54 if (err != NumpyUtilities::NO_ERROR)
55 {
56 std::cout << "Error reading file, code " << err << std::endl;
57 return 1;
58 }
59 if (memcmp(buffer.get(), read_buffer.get(), total_size * sizeof(T)))
60 {
61 std::cout << "Miscompare" << std::endl;
62 return 1;
63 }
64 return 0;
65}
66
67template <class T>
68int test_float_type(std::vector<int32_t> shape, std::default_random_engine& gen, std::string& filename)
69{
70 size_t total_size = 1;
71 std::uniform_real_distribution<T> gen_data(std::numeric_limits<T>::min(), std::numeric_limits<T>::max());
72
73 for (auto i : shape)
74 {
75 total_size *= i;
76 }
77
78 auto buffer = std::make_unique<T[]>(total_size);
Eric Kunzecc426df2024-01-03 00:27:59 +000079 for (size_t i = 0; i < total_size; i++)
Eric Kunze2364dcd2021-04-26 11:06:57 -070080 {
81 buffer[i] = gen_data(gen);
82 }
83
84 NumpyUtilities::NPError err = NumpyUtilities::writeToNpyFile(filename.c_str(), shape, buffer.get());
85 if (err != NumpyUtilities::NO_ERROR)
86 {
87 std::cout << "Error writing file, code " << err << std::endl;
88 return 1;
89 }
90
91 auto read_buffer = std::make_unique<T[]>(total_size);
92 err = NumpyUtilities::readFromNpyFile(filename.c_str(), total_size, read_buffer.get());
93 if (err != NumpyUtilities::NO_ERROR)
94 {
95 std::cout << "Error reading file, code " << err << std::endl;
96 return 1;
97 }
98 if (memcmp(buffer.get(), read_buffer.get(), total_size * sizeof(T)))
99 {
100 std::cout << "Miscompare" << std::endl;
101 return 1;
102 }
103 return 0;
104}
105
Tai Ly3ef34fb2023-04-04 20:34:05 +0000106template <class T>
107int test_double_type(std::vector<int32_t> shape, std::default_random_engine& gen, std::string& filename)
108{
109 size_t total_size = 1;
110 std::uniform_real_distribution<T> gen_data(std::numeric_limits<T>::min(), std::numeric_limits<T>::max());
111
112 for (auto i : shape)
113 {
114 total_size *= i;
115 }
116
117 auto buffer = std::make_unique<T[]>(total_size);
Eric Kunzecc426df2024-01-03 00:27:59 +0000118 for (size_t i = 0; i < total_size; i++)
Tai Ly3ef34fb2023-04-04 20:34:05 +0000119 {
120 buffer[i] = gen_data(gen);
121 }
122
123 NumpyUtilities::NPError err = NumpyUtilities::writeToNpyFile(filename.c_str(), shape, buffer.get());
124 if (err != NumpyUtilities::NO_ERROR)
125 {
126 std::cout << "Error writing file, code " << err << std::endl;
127 return 1;
128 }
129
130 auto read_buffer = std::make_unique<T[]>(total_size);
131 err = NumpyUtilities::readFromNpyFile(filename.c_str(), total_size, read_buffer.get());
132 if (err != NumpyUtilities::NO_ERROR)
133 {
134 std::cout << "Error reading file, code " << err << std::endl;
135 return 1;
136 }
137 if (memcmp(buffer.get(), read_buffer.get(), total_size * sizeof(T)))
138 {
139 std::cout << "Miscompare" << std::endl;
140 return 1;
141 }
142 return 0;
143}
144
Eric Kunze2364dcd2021-04-26 11:06:57 -0700145int test_bool_type(std::vector<int32_t> shape, std::default_random_engine& gen, std::string& filename)
146{
147 size_t total_size = 1;
148 std::uniform_int_distribution<uint32_t> gen_data(0, 1);
149
150 for (auto i : shape)
151 {
152 total_size *= i;
153 }
154
155 auto buffer = std::make_unique<bool[]>(total_size);
Eric Kunzecc426df2024-01-03 00:27:59 +0000156 for (size_t i = 0; i < total_size; i++)
Eric Kunze2364dcd2021-04-26 11:06:57 -0700157 {
158 buffer[i] = (gen_data(gen)) ? true : false;
159 }
160
161 NumpyUtilities::NPError err = NumpyUtilities::writeToNpyFile(filename.c_str(), shape, buffer.get());
162 if (err != NumpyUtilities::NO_ERROR)
163 {
164 std::cout << "Error writing file, code " << err << std::endl;
165 return 1;
166 }
167
168 auto read_buffer = std::make_unique<bool[]>(total_size);
169 err = NumpyUtilities::readFromNpyFile(filename.c_str(), total_size, read_buffer.get());
170 if (err != NumpyUtilities::NO_ERROR)
171 {
172 std::cout << "Error reading file, code " << err << std::endl;
173 return 1;
174 }
175
176 if (memcmp(buffer.get(), read_buffer.get(), total_size * sizeof(bool)))
177 {
178 std::cout << "Miscompare" << std::endl;
179 return 1;
180 }
181 return 0;
182}
183
184int main(int argc, char** argv)
185{
Eric Kunzecc426df2024-01-03 00:27:59 +0000186 int32_t seed = 1;
Eric Kunze2364dcd2021-04-26 11:06:57 -0700187 std::string str_type;
188 std::string str_shape;
189 std::string filename = "npytest.npy";
190 std::vector<int32_t> shape;
Eric Kunze2364dcd2021-04-26 11:06:57 -0700191 int opt;
Eric Kunzecc426df2024-01-03 00:27:59 +0000192 while ((opt = getopt(argc, argv, "d:f:s:t:")) != -1)
Eric Kunze2364dcd2021-04-26 11:06:57 -0700193 {
194 switch (opt)
195 {
196 case 'd':
197 str_type = optarg;
198 break;
199 case 'f':
200 filename = optarg;
201 break;
202 case 's':
203 seed = strtol(optarg, nullptr, 0);
204 break;
205 case 't':
206 str_shape = optarg;
207 break;
Eric Kunze2364dcd2021-04-26 11:06:57 -0700208 default:
209 std::cerr << "Invalid argument" << std::endl;
210 break;
211 }
212 }
213 if (str_shape == "")
214 {
215 usage();
216 return 1;
217 }
218
219 // parse shape from argument
220 std::stringstream ss(str_shape);
221 while (ss.good())
222 {
223 std::string substr;
224 size_t pos;
225 std::getline(ss, substr, ',');
226 if (substr == "")
227 break;
228 int val = stoi(substr, &pos, 0);
229 assert(val);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700230 shape.push_back(val);
231 }
232
233 std::default_random_engine gen(seed);
234
235 // run with type from argument
236 if (str_type == "int32")
237 {
238 return test_int_type<int32_t>(shape, gen, filename);
239 }
240 else if (str_type == "int64")
241 {
242 return test_int_type<int64_t>(shape, gen, filename);
243 }
244 else if (str_type == "float")
245 {
246 return test_float_type<float>(shape, gen, filename);
247 }
Tai Ly3ef34fb2023-04-04 20:34:05 +0000248 else if (str_type == "double")
249 {
250 return test_double_type<double>(shape, gen, filename);
251 }
Eric Kunze2364dcd2021-04-26 11:06:57 -0700252 else if (str_type == "bool")
253 {
254 return test_bool_type(shape, gen, filename);
255 }
256 else
257 {
258 std::cout << "Unknown type " << str_type << std::endl;
259 usage();
260 return 1;
261 }
262}