blob: 27ec4647dc718669528430e4529ba6c1f4453f5a [file] [log] [blame]
Eric Kunze2364dcd2021-04-26 11:06:57 -07001// Copyright (c) 2021, ARM Limited.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
#include <getopt.h>

#include <cassert>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <limits>
#include <memory>
#include <random>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <vector>

#include <tosa_serialization_handler.h>
20
21using namespace tosa;
22
/// Writes the one-line command synopsis for this test binary to stdout.
void usage()
{
    const char* const synopsis = "Usage: serialization_npy_test -f <filename> -t <shape> -d <datatype> -s <seed>";
    std::cout << synopsis << std::endl;
}
27
28template <class T>
29int test_int_type(std::vector<int32_t> shape, std::default_random_engine& gen, std::string& filename)
30{
31 size_t total_size = 1;
32 std::uniform_int_distribution<T> gen_data(std::numeric_limits<T>::min(), std::numeric_limits<T>::max());
33
34 for (auto i : shape)
35 {
36 total_size *= i;
37 }
38
39 auto buffer = std::make_unique<T[]>(total_size);
40 for (int i = 0; i < total_size; i++)
41 {
42 buffer[i] = gen_data(gen);
43 }
44
45 NumpyUtilities::NPError err = NumpyUtilities::writeToNpyFile(filename.c_str(), shape, buffer.get());
46 if (err != NumpyUtilities::NO_ERROR)
47 {
48 std::cout << "Error writing file, code " << err << std::endl;
49 return 1;
50 }
51
52 auto read_buffer = std::make_unique<T[]>(total_size);
53 err = NumpyUtilities::readFromNpyFile(filename.c_str(), total_size, read_buffer.get());
54 if (err != NumpyUtilities::NO_ERROR)
55 {
56 std::cout << "Error reading file, code " << err << std::endl;
57 return 1;
58 }
59 if (memcmp(buffer.get(), read_buffer.get(), total_size * sizeof(T)))
60 {
61 std::cout << "Miscompare" << std::endl;
62 return 1;
63 }
64 return 0;
65}
66
67template <class T>
68int test_float_type(std::vector<int32_t> shape, std::default_random_engine& gen, std::string& filename)
69{
70 size_t total_size = 1;
71 std::uniform_real_distribution<T> gen_data(std::numeric_limits<T>::min(), std::numeric_limits<T>::max());
72
73 for (auto i : shape)
74 {
75 total_size *= i;
76 }
77
78 auto buffer = std::make_unique<T[]>(total_size);
79 for (int i = 0; i < total_size; i++)
80 {
81 buffer[i] = gen_data(gen);
82 }
83
84 NumpyUtilities::NPError err = NumpyUtilities::writeToNpyFile(filename.c_str(), shape, buffer.get());
85 if (err != NumpyUtilities::NO_ERROR)
86 {
87 std::cout << "Error writing file, code " << err << std::endl;
88 return 1;
89 }
90
91 auto read_buffer = std::make_unique<T[]>(total_size);
92 err = NumpyUtilities::readFromNpyFile(filename.c_str(), total_size, read_buffer.get());
93 if (err != NumpyUtilities::NO_ERROR)
94 {
95 std::cout << "Error reading file, code " << err << std::endl;
96 return 1;
97 }
98 if (memcmp(buffer.get(), read_buffer.get(), total_size * sizeof(T)))
99 {
100 std::cout << "Miscompare" << std::endl;
101 return 1;
102 }
103 return 0;
104}
105
106int test_bool_type(std::vector<int32_t> shape, std::default_random_engine& gen, std::string& filename)
107{
108 size_t total_size = 1;
109 std::uniform_int_distribution<uint32_t> gen_data(0, 1);
110
111 for (auto i : shape)
112 {
113 total_size *= i;
114 }
115
116 auto buffer = std::make_unique<bool[]>(total_size);
117 for (int i = 0; i < total_size; i++)
118 {
119 buffer[i] = (gen_data(gen)) ? true : false;
120 }
121
122 NumpyUtilities::NPError err = NumpyUtilities::writeToNpyFile(filename.c_str(), shape, buffer.get());
123 if (err != NumpyUtilities::NO_ERROR)
124 {
125 std::cout << "Error writing file, code " << err << std::endl;
126 return 1;
127 }
128
129 auto read_buffer = std::make_unique<bool[]>(total_size);
130 err = NumpyUtilities::readFromNpyFile(filename.c_str(), total_size, read_buffer.get());
131 if (err != NumpyUtilities::NO_ERROR)
132 {
133 std::cout << "Error reading file, code " << err << std::endl;
134 return 1;
135 }
136
137 if (memcmp(buffer.get(), read_buffer.get(), total_size * sizeof(bool)))
138 {
139 std::cout << "Miscompare" << std::endl;
140 return 1;
141 }
142 return 0;
143}
144
145int main(int argc, char** argv)
146{
147 size_t total_size = 1;
148 int32_t seed = 1;
149 std::string str_type;
150 std::string str_shape;
151 std::string filename = "npytest.npy";
152 std::vector<int32_t> shape;
153 bool verbose = false;
154 int opt;
155 while ((opt = getopt(argc, argv, "d:f:s:t:v")) != -1)
156 {
157 switch (opt)
158 {
159 case 'd':
160 str_type = optarg;
161 break;
162 case 'f':
163 filename = optarg;
164 break;
165 case 's':
166 seed = strtol(optarg, nullptr, 0);
167 break;
168 case 't':
169 str_shape = optarg;
170 break;
171 case 'v':
172 verbose = true;
173 break;
174 default:
175 std::cerr << "Invalid argument" << std::endl;
176 break;
177 }
178 }
179 if (str_shape == "")
180 {
181 usage();
182 return 1;
183 }
184
185 // parse shape from argument
186 std::stringstream ss(str_shape);
187 while (ss.good())
188 {
189 std::string substr;
190 size_t pos;
191 std::getline(ss, substr, ',');
192 if (substr == "")
193 break;
194 int val = stoi(substr, &pos, 0);
195 assert(val);
196 total_size *= val;
197 shape.push_back(val);
198 }
199
200 std::default_random_engine gen(seed);
201
202 // run with type from argument
203 if (str_type == "int32")
204 {
205 return test_int_type<int32_t>(shape, gen, filename);
206 }
207 else if (str_type == "int64")
208 {
209 return test_int_type<int64_t>(shape, gen, filename);
210 }
211 else if (str_type == "float")
212 {
213 return test_float_type<float>(shape, gen, filename);
214 }
215 else if (str_type == "bool")
216 {
217 return test_bool_type(shape, gen, filename);
218 }
219 else
220 {
221 std::cout << "Unknown type " << str_type << std::endl;
222 usage();
223 return 1;
224 }
225}