blob: 338817051393f0de0c409ea78ee399ef3d65c006 [file] [log] [blame]
Kristofer Jonsson116a6352020-08-20 17:25:23 +02001/*
2 * Copyright (c) 2020 Arm Limited. All rights reserved.
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 *
6 * Licensed under the Apache License, Version 2.0 (the License); you may
7 * not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
#include <ethosu.hpp>

#include <uapi/ethosu.h>

#include <unistd.h>

#include <cstdlib>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <list>
#include <memory>
#include <stdexcept>
#include <string>

using namespace std;
using namespace EthosU;
32
33namespace
34{
35int defaultTimeout = 60;
36
37void help(const string exe)
38{
39 cerr << "Usage: " << exe << " [ARGS]\n";
40 cerr << "\n";
41 cerr << "Arguments:\n";
42 cerr << " -h --help Print this help message.\n";
43 cerr << " -n --network File to read network from.\n";
44 cerr << " -i --ifm File to read IFM from.\n";
45 cerr << " -o --ofm File to write IFM to.\n";
46 cerr << " -t --timeout Timeout in seconds (default " << defaultTimeout << ").\n";
47 cerr << " -p Print OFM.\n";
48 cerr << endl;
49}
50
51void rangeCheck(const int i, const int argc, const string arg)
52{
53 if (i >= argc)
54 {
55 cerr << "Error: Missing argument to '" << arg << "'" << endl;
56 exit(1);
57 }
58}
59
60shared_ptr<Buffer> allocAndFill(Device &device, const string filename)
61{
62 ifstream stream(filename, ios::binary);
63 if (!stream.is_open())
64 {
65 cerr << "Error: Failed to open '" << filename << "'" << endl;
66 exit(1);
67 }
68
69 stream.seekg(0, ios_base::end);
70 size_t size = stream.tellg();
71 stream.seekg(0, ios_base::beg);
72
73 shared_ptr<Buffer> buffer = make_shared<Buffer>(device, size);
74 buffer->resize(size);
75 stream.read(buffer->data(), size);
76
77 return buffer;
78}
79
80std::ostream &operator<<(std::ostream &os, Buffer &buf)
81{
82 char *c = buf.data();
83 const char *end = c + buf.size();
84
85 while (c < end)
86 {
87 os << hex << setw(2) << static_cast<int>(*c++) << " " << dec;
88 }
89
90 return os;
91}
92
93}
94
95int main(int argc, char *argv[])
96{
97 const string exe = argv[0];
98 string networkArg;
99 list<string> ifmArg;
100 string ofmArg;
101 int timeout = defaultTimeout;
102 bool print = false;
103
104 for (int i = 1; i < argc; ++i)
105 {
106 const string arg(argv[i]);
107
108 if (arg == "-h" || arg == "--help")
109 {
110 help(exe);
111 exit(1);
112 }
113 else if (arg == "--network" || arg == "-n")
114 {
115 rangeCheck(++i, argc, arg);
116 networkArg = argv[i];
117 }
118 else if (arg == "--ifm" || arg == "-i")
119 {
120 rangeCheck(++i, argc, arg);
121 ifmArg.push_back(argv[i]);
122 }
123 else if (arg == "--ofm" || arg == "-o")
124 {
125 rangeCheck(++i, argc, arg);
126 ofmArg = argv[i];
127 }
128 else if (arg == "--timeout" || arg == "-t")
129 {
130 rangeCheck(++i, argc, arg);
131 timeout = std::stoi(argv[i]);
132 }
133 else if (arg == "-p")
134 {
135 print = true;
136 }
137 else
138 {
139 cerr << "Error: Invalid argument '" << arg << "'" << endl;
140 help(exe);
141 exit(1);
142 }
143 }
144
145 if (networkArg.empty())
146 {
147 cerr << "Error: Missing 'network' argument" << endl;
148 exit(1);
149 }
150
151 if (ifmArg.empty())
152 {
153 cerr << "Error: Missing 'ifm' argument" << endl;
154 exit(1);
155 }
156
157 if (ofmArg.empty())
158 {
159 cerr << "Error: Missing 'ofm' argument" << endl;
160 exit(1);
161 }
162
163 try
164 {
165 Device device;
166
167 cout << "Send ping" << endl;
168 device.ioctl(ETHOSU_IOCTL_PING);
169
170 cout << "Create network" << endl;
171 shared_ptr<Buffer> networkBuffer = allocAndFill(device, networkArg);
172 shared_ptr<Network> network = make_shared<Network>(device, networkBuffer);
173
174 cout << "Queue inferences" << endl;
175 list<shared_ptr<Inference>> inferences;
176
177 for (auto &filename: ifmArg)
178 {
179 cout << "Create inference" << endl;
180 shared_ptr<Buffer> ifmBuffer = allocAndFill(device, filename);
181 shared_ptr<Buffer> ofmBuffer = make_shared<Buffer>(device, 128 * 1024);
182 shared_ptr<Inference> inference = make_shared<Inference>(network, ifmBuffer, ofmBuffer);
183 inferences.push_back(inference);
184 }
185
186 cout << "Wait for inferences" << endl;
187
188 int ofmIndex = 0;
189 for (auto &inference: inferences)
190 {
191 inference->wait(timeout);
192
193 string status = inference->failed() ? "failed" : "success";
194 cout << "Inference status: " << status << endl;
195
196 string ofmFilename = ofmArg + "." + to_string(ofmIndex);
197 ofstream ofmStream(ofmFilename, ios::binary);
198 if (!ofmStream.is_open())
199 {
200 cerr << "Error: Failed to open '" << ofmFilename << "'" << endl;
201 exit(1);
202 }
203
204 if (!inference->failed())
205 {
206 shared_ptr<Buffer> ofmBuffer = inference->getOfmBuffer();
207
208 cout << "OFM size: " << ofmBuffer->size() << endl;
209
210 if (print)
211 {
212 cout << "OFM data: " << *ofmBuffer << endl;
213 }
214
215 ofmStream.write(ofmBuffer->data(), ofmBuffer->size());
216 }
217
218 ofmIndex++;
219 }
220 }
221 catch (Exception &e)
222 {
223 cerr << "Error: " << e.what() << endl;
224 return 1;
225 }
226
227 return 0;
228}