blob: ae8d2a7b14be4f06bbb02c5c9e91abda0ef82f2e [file] [log] [blame]
/*
 * Copyright (c) 2020 Arm Limited. All rights reserved.
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
18
Kristofer Jonsson116a6352020-08-20 17:25:23 +020019
#include <ethosu.hpp>
#include <uapi/ethosu.h>

#include <unistd.h>
#include <exception>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <list>
#include <string>
29
30using namespace std;
31using namespace EthosU;
32
33namespace
34{
35int defaultTimeout = 60;
36
37void help(const string exe)
38{
39 cerr << "Usage: " << exe << " [ARGS]\n";
40 cerr << "\n";
41 cerr << "Arguments:\n";
42 cerr << " -h --help Print this help message.\n";
43 cerr << " -n --network File to read network from.\n";
44 cerr << " -i --ifm File to read IFM from.\n";
45 cerr << " -o --ofm File to write IFM to.\n";
46 cerr << " -t --timeout Timeout in seconds (default " << defaultTimeout << ").\n";
47 cerr << " -p Print OFM.\n";
48 cerr << endl;
49}
50
51void rangeCheck(const int i, const int argc, const string arg)
52{
53 if (i >= argc)
54 {
55 cerr << "Error: Missing argument to '" << arg << "'" << endl;
56 exit(1);
57 }
58}
59
60shared_ptr<Buffer> allocAndFill(Device &device, const string filename)
61{
62 ifstream stream(filename, ios::binary);
63 if (!stream.is_open())
64 {
65 cerr << "Error: Failed to open '" << filename << "'" << endl;
66 exit(1);
67 }
68
69 stream.seekg(0, ios_base::end);
70 size_t size = stream.tellg();
71 stream.seekg(0, ios_base::beg);
72
73 shared_ptr<Buffer> buffer = make_shared<Buffer>(device, size);
74 buffer->resize(size);
75 stream.read(buffer->data(), size);
76
77 return buffer;
78}
79
Kristofer Jonssonb74492c2020-09-10 13:26:01 +020080shared_ptr<Inference> createInference(Device &device, shared_ptr<Network> &network, const string &filename)
81{
82 // Open IFM file
83 ifstream stream(filename, ios::binary);
84 if (!stream.is_open())
85 {
86 cerr << "Error: Failed to open '" << filename << "'" << endl;
87 exit(1);
88 }
89
90 // Get IFM file size
91 stream.seekg(0, ios_base::end);
92 size_t size = stream.tellg();
93 stream.seekg(0, ios_base::beg);
94
95 if (size != network->getIfmSize())
96 {
97 cerr << "Error: IFM size does not match network size. filename=" << filename << ", size=" << size << ", network=" << network->getIfmSize() << endl;
98 exit(1);
99 }
100
101 // Create IFM buffers
102 vector<shared_ptr<Buffer>> ifm;
103 for (auto size: network->getIfmDims())
104 {
105 shared_ptr<Buffer> buffer = make_shared<Buffer>(device, size);
106 buffer->resize(size);
107 stream.read(buffer->data(), size);
108
109 if (!stream)
110 {
111 cerr << "Error: Failed to read IFM" << endl;
112 exit(1);
113 }
114
115 ifm.push_back(buffer);
116 }
117
118 // Create OFM buffers
119 vector<shared_ptr<Buffer>> ofm;
120 for (auto size: network->getOfmDims())
121 {
122 ofm.push_back(make_shared<Buffer>(device, size));
123 }
124
125 return make_shared<Inference>(network, ifm.begin(), ifm.end(), ofm.begin(), ofm.end());
126}
127
128ostream &operator<<(ostream &os, Buffer &buf)
Kristofer Jonsson116a6352020-08-20 17:25:23 +0200129{
130 char *c = buf.data();
131 const char *end = c + buf.size();
132
133 while (c < end)
134 {
135 os << hex << setw(2) << static_cast<int>(*c++) << " " << dec;
136 }
137
138 return os;
139}
140
141}
142
143int main(int argc, char *argv[])
144{
145 const string exe = argv[0];
146 string networkArg;
147 list<string> ifmArg;
148 string ofmArg;
149 int timeout = defaultTimeout;
150 bool print = false;
151
152 for (int i = 1; i < argc; ++i)
153 {
154 const string arg(argv[i]);
155
156 if (arg == "-h" || arg == "--help")
157 {
158 help(exe);
159 exit(1);
160 }
161 else if (arg == "--network" || arg == "-n")
162 {
163 rangeCheck(++i, argc, arg);
164 networkArg = argv[i];
165 }
166 else if (arg == "--ifm" || arg == "-i")
167 {
168 rangeCheck(++i, argc, arg);
169 ifmArg.push_back(argv[i]);
170 }
171 else if (arg == "--ofm" || arg == "-o")
172 {
173 rangeCheck(++i, argc, arg);
174 ofmArg = argv[i];
175 }
176 else if (arg == "--timeout" || arg == "-t")
177 {
178 rangeCheck(++i, argc, arg);
Kristofer Jonssonb74492c2020-09-10 13:26:01 +0200179 timeout = stoi(argv[i]);
Kristofer Jonsson116a6352020-08-20 17:25:23 +0200180 }
181 else if (arg == "-p")
182 {
183 print = true;
184 }
185 else
186 {
187 cerr << "Error: Invalid argument '" << arg << "'" << endl;
188 help(exe);
189 exit(1);
190 }
191 }
192
193 if (networkArg.empty())
194 {
195 cerr << "Error: Missing 'network' argument" << endl;
196 exit(1);
197 }
198
199 if (ifmArg.empty())
200 {
201 cerr << "Error: Missing 'ifm' argument" << endl;
202 exit(1);
203 }
204
205 if (ofmArg.empty())
206 {
207 cerr << "Error: Missing 'ofm' argument" << endl;
208 exit(1);
209 }
210
211 try
212 {
213 Device device;
214
215 cout << "Send ping" << endl;
216 device.ioctl(ETHOSU_IOCTL_PING);
217
Kristofer Jonssonb74492c2020-09-10 13:26:01 +0200218 /* Create network */
Kristofer Jonsson116a6352020-08-20 17:25:23 +0200219 cout << "Create network" << endl;
220 shared_ptr<Buffer> networkBuffer = allocAndFill(device, networkArg);
221 shared_ptr<Network> network = make_shared<Network>(device, networkBuffer);
222
Kristofer Jonssonb74492c2020-09-10 13:26:01 +0200223 /* Create one inference per IFM */
Kristofer Jonsson116a6352020-08-20 17:25:23 +0200224 list<shared_ptr<Inference>> inferences;
Kristofer Jonsson116a6352020-08-20 17:25:23 +0200225 for (auto &filename: ifmArg)
226 {
227 cout << "Create inference" << endl;
Kristofer Jonssonb74492c2020-09-10 13:26:01 +0200228 inferences.push_back(createInference(device, network, filename));
Kristofer Jonsson116a6352020-08-20 17:25:23 +0200229 }
230
231 cout << "Wait for inferences" << endl;
232
233 int ofmIndex = 0;
234 for (auto &inference: inferences)
235 {
236 inference->wait(timeout);
237
238 string status = inference->failed() ? "failed" : "success";
239 cout << "Inference status: " << status << endl;
240
241 string ofmFilename = ofmArg + "." + to_string(ofmIndex);
242 ofstream ofmStream(ofmFilename, ios::binary);
243 if (!ofmStream.is_open())
244 {
245 cerr << "Error: Failed to open '" << ofmFilename << "'" << endl;
246 exit(1);
247 }
248
249 if (!inference->failed())
250 {
Kristofer Jonssonb74492c2020-09-10 13:26:01 +0200251 for (auto &ofmBuffer: inference->getOfmBuffers())
Kristofer Jonsson116a6352020-08-20 17:25:23 +0200252 {
Kristofer Jonssonb74492c2020-09-10 13:26:01 +0200253 cout << "OFM size: " << ofmBuffer->size() << endl;
Kristofer Jonsson116a6352020-08-20 17:25:23 +0200254
Kristofer Jonssonb74492c2020-09-10 13:26:01 +0200255 if (print)
256 {
257 cout << "OFM data: " << *ofmBuffer << endl;
258 }
259
260 ofmStream.write(ofmBuffer->data(), ofmBuffer->size());
261 }
Kristofer Jonsson116a6352020-08-20 17:25:23 +0200262 }
263
264 ofmIndex++;
265 }
266 }
267 catch (Exception &e)
268 {
269 cerr << "Error: " << e.what() << endl;
270 return 1;
271 }
272
273 return 0;
274}