blob: 34fb07c5733e86a69ed56e4750783c0f61ab0328 [file] [log] [blame]
/*
* SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the License); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ethosu.hpp>
#include <uapi/ethosu.h>
#include <algorithm>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <list>
#include <stdexcept>
#include <stdio.h>
#include <string>
#include <unistd.h>
#include <utility>
using namespace std;
using namespace EthosU;
namespace {
/* Default time to wait for an inference, in nanoseconds (60 seconds). Used when no --timeout is given. */
constexpr int64_t defaultTimeout = 60000000000;
/*
 * Print the command line usage to stderr.
 *
 * exe: name of the executable, shown in the "Usage:" line.
 */
void help(const string &exe) {
    cerr << "Usage: " << exe << " [ARGS]\n";
    cerr << "\n";
    cerr << "Arguments:\n";
    cerr << " -h --help Print this help message.\n";
    cerr << " -n --network File to read network from.\n";
    cerr << " --index Network model index, stored in firmware binary.\n";
    cerr << " -i --ifm File to read IFM from.\n";
    cerr << " -o --ofm File to write OFM to.\n";
    // main() rejects any counter index >= getMaxPmuEventCounters(), so the last valid index is max - 1.
    cerr << " -P --pmu [0.." << (Inference::getMaxPmuEventCounters() - 1) << "] eventid.\n";
    cerr << " PMU counter to enable followed by eventid, can be passed multiple times.\n";
    cerr << " -C --cycles Enable cycle counter for inference.\n";
    cerr << " -t --timeout Timeout in nanoseconds (default " << defaultTimeout << ").\n";
    cerr << " -p Print OFM.\n";
    cerr << endl;
}
/*
 * Verify that an option is followed by a value on the command line.
 *
 * i:    index at which the value is expected.
 * argc: number of command line arguments.
 * arg:  option being parsed; only used in the error message.
 *
 * Returns normally when index i is inside the argument list. Otherwise an
 * error is printed to stderr and the process exits with code 1.
 */
void rangeCheck(const int i, const int argc, const std::string &arg) {
    const bool valuePresent = i < argc;
    if (valuePresent) {
        return;
    }

    std::cerr << "Error: Missing argument to '" << arg << "'" << std::endl;
    exit(1);
}
/*
 * Read a complete network file into memory.
 *
 * filename: path of the file to load.
 *
 * Returns the file contents together with their size in bytes.
 *
 * Prints an error to stderr and exits with code 1 if the file cannot be
 * opened, its size cannot be determined or it cannot be read in full.
 */
std::pair<std::unique_ptr<unsigned char[]>, size_t> getNetworkData(const std::string &filename) {
    std::ifstream stream(filename, std::ios::binary);
    if (!stream.is_open()) {
        std::cerr << "Error: Failed to open '" << filename << "'" << std::endl;
        exit(1);
    }

    // The file size is the stream position after seeking to the end.
    stream.seekg(0, std::ios_base::end);
    const std::streamoff length = stream.tellg();
    stream.seekg(0, std::ios_base::beg);

    // tellg() reports failure as -1; converted to size_t that would become a huge allocation request.
    if (length < 0 || !stream) {
        std::cerr << "Error: Failed to get size of '" << filename << "'" << std::endl;
        exit(1);
    }

    const size_t size = static_cast<size_t>(length);
    std::unique_ptr<unsigned char[]> data = std::make_unique<unsigned char[]>(size);

    // A short read would otherwise hand a partially filled buffer to the caller.
    stream.read(reinterpret_cast<char *>(data.get()), static_cast<std::streamsize>(length));
    if (!stream) {
        std::cerr << "Error: Failed to read '" << filename << "'" << std::endl;
        exit(1);
    }

    return std::make_pair(std::move(data), size);
}
/*
 * Create an inference for a network, with the input data taken from a file.
 *
 * device:             device passed to every Buffer that is constructed.
 * network:            network to run; it also supplies the expected IFM and OFM sizes.
 * filename:           file holding the IFM data. Its length must equal network->getIfmSize().
 * counters:           PMU counter configuration, forwarded unchanged to the Inference.
 * enableCycleCounter: forwarded unchanged to the Inference.
 *
 * The file is consumed sequentially: for each entry of network->getIfmDims()
 * a Buffer of that size is constructed and filled from the file. For each
 * entry of network->getOfmDims() a Buffer of that size is constructed.
 *
 * Prints an error to stderr and exits with code 1 if the file cannot be
 * opened, has the wrong size or cannot be read.
 */
std::shared_ptr<Inference> createInference(Device &device,
                                           std::shared_ptr<Network> &network,
                                           const std::string &filename,
                                           const std::vector<uint8_t> &counters,
                                           bool enableCycleCounter) {
    std::ifstream ifmFile(filename, std::ios::binary);
    if (!ifmFile.is_open()) {
        std::cerr << "Error: Failed to open '" << filename << "'" << std::endl;
        exit(1);
    }

    // Measure the file by seeking to its end, then rewind before reading.
    ifmFile.seekg(0, std::ios_base::end);
    size_t fileSize = ifmFile.tellg();
    ifmFile.seekg(0, std::ios_base::beg);

    if (fileSize != network->getIfmSize()) {
        std::cerr << "Error: IFM size does not match network size. filename=" << filename << ", size=" << fileSize
                  << ", network=" << network->getIfmSize() << std::endl;
        exit(1);
    }

    // One input buffer per IFM, each filled with the next chunk of the file.
    std::vector<std::shared_ptr<Buffer>> inputs;
    for (auto ifmSize : network->getIfmDims()) {
        std::shared_ptr<Buffer> input = std::make_shared<Buffer>(device, ifmSize);
        ifmFile.read(input->data(), ifmSize);
        if (!ifmFile) {
            std::cerr << "Error: Failed to read IFM" << std::endl;
            exit(1);
        }
        inputs.push_back(input);
    }

    // One output buffer per OFM.
    std::vector<std::shared_ptr<Buffer>> outputs;
    for (auto ofmSize : network->getOfmDims()) {
        outputs.push_back(std::make_shared<Buffer>(device, ofmSize));
    }

    return std::make_shared<Inference>(
        network, inputs.begin(), inputs.end(), outputs.begin(), outputs.end(), counters, enableCycleCounter);
}
/*
 * Write the contents of a buffer to a stream as hexadecimal byte values,
 * each followed by a single space.
 *
 * The stream is switched back to decimal output after every byte.
 */
ostream &operator<<(ostream &os, Buffer &buf) {
    const char *c = buf.data();
    const char *end = c + buf.size();
    while (c < end) {
        // Convert through unsigned char: where char is signed, a direct cast to int sign-extends
        // bytes >= 0x80 and prints them as e.g. "ffffff80" instead of "80".
        const unsigned value = static_cast<unsigned char>(*c++);
        os << hex << setw(2) << value << " " << dec;
    }
    return os;
}
} // namespace
int main(int argc, char *argv[]) {
const string exe = argv[0];
string networkArg;
int networkIndex = -1;
list<string> ifmArg;
vector<uint8_t> enabledCounters(Inference::getMaxPmuEventCounters());
string ofmArg;
int64_t timeout = defaultTimeout;
bool print = false;
bool enableCycleCounter = false;
for (int i = 1; i < argc; ++i) {
const string arg(argv[i]);
if (arg == "-h" || arg == "--help") {
help(exe);
exit(1);
} else if (arg == "--network" || arg == "-n") {
rangeCheck(++i, argc, arg);
networkArg = argv[i];
} else if (arg == "--index") {
rangeCheck(++i, argc, arg);
networkIndex = stoi(argv[i]);
} else if (arg == "--ifm" || arg == "-i") {
rangeCheck(++i, argc, arg);
ifmArg.push_back(argv[i]);
} else if (arg == "--ofm" || arg == "-o") {
rangeCheck(++i, argc, arg);
ofmArg = argv[i];
} else if (arg == "--timeout" || arg == "-t") {
rangeCheck(++i, argc, arg);
timeout = stoll(argv[i]);
} else if (arg == "--pmu" || arg == "-P") {
unsigned pmu = 0, event = 0;
rangeCheck(++i, argc, arg);
pmu = stoi(argv[i]);
rangeCheck(++i, argc, arg);
event = stoi(argv[i]);
if (pmu >= enabledCounters.size()) {
cerr << "PMU out of bounds!" << endl;
help(exe);
exit(1);
}
cout << argv[i] << " -> Enabling " << pmu << " with event " << event << endl;
enabledCounters[pmu] = event;
} else if (arg == "--cycles" || arg == "-C") {
enableCycleCounter = true;
} else if (arg == "-p") {
print = true;
} else {
cerr << "Error: Invalid argument '" << arg << "'" << endl;
help(exe);
exit(1);
}
}
if (networkArg.empty()) {
cerr << "Error: Missing 'network' argument" << endl;
exit(1);
}
if (ifmArg.empty()) {
cerr << "Error: Missing 'ifm' argument" << endl;
exit(1);
}
if (ofmArg.empty()) {
cerr << "Error: Missing 'ofm' argument" << endl;
exit(1);
}
try {
cout << "Driver library version:" << getLibraryVersion() << endl;
Device device;
cout << "Kernel driver version:" << device.getDriverVersion() << endl;
cout << "Send Ping" << endl;
device.ioctl(ETHOSU_IOCTL_PING);
cout << "Send capabilities request" << endl;
Capabilities capabilities = device.capabilities();
cout << "Capabilities:" << endl
<< "\tversion_status:" << unsigned(capabilities.hwId.versionStatus) << endl
<< "\tversion:" << capabilities.hwId.version << endl
<< "\tproduct:" << capabilities.hwId.product << endl
<< "\tarchitecture:" << capabilities.hwId.architecture << endl
<< "\tdriver:" << capabilities.driver << endl
<< "\tmacs_per_cc:" << unsigned(capabilities.hwCfg.macsPerClockCycle) << endl
<< "\tcmd_stream_version:" << unsigned(capabilities.hwCfg.cmdStreamVersion) << endl
<< "\ttype:" << capabilities.hwCfg.type << endl
<< "\tcustom_dma:" << std::boolalpha << capabilities.hwCfg.customDma << endl;
/* Create network */
cout << "Create network" << endl;
shared_ptr<Network> network;
if (networkIndex < 0) {
auto networkData = getNetworkData(networkArg);
network = make_shared<Network>(device, networkData.first.get(), networkData.second);
} else {
network = make_shared<Network>(device, networkIndex);
}
/* Create one inference per IFM */
list<shared_ptr<Inference>> inferences;
for (auto &filename : ifmArg) {
cout << "Create inference" << endl;
inferences.push_back(createInference(device, network, filename, enabledCounters, enableCycleCounter));
}
cout << "Wait for inferences" << endl;
int ofmIndex = 0;
for (auto &inference : inferences) {
cout << "Inference status: " << inference->status() << endl;
/* make sure the wait completes ok */
try {
cout << "Wait for inference" << endl;
bool timedout = inference->wait(timeout);
if (timedout) {
cout << "Inference timed out, cancelling it" << endl;
bool aborted = inference->cancel();
if (!aborted || inference->status() != InferenceStatus::ABORTED) {
cout << "Inference cancellation failed" << endl;
}
}
} catch (std::exception &e) {
cout << "Failed to wait for or to cancel inference: " << e.what() << endl;
exit(1);
}
cout << "Inference status: " << inference->status() << endl;
if (inference->status() == InferenceStatus::OK) {
string ofmFilename = ofmArg + "." + to_string(ofmIndex);
ofstream ofmStream(ofmFilename, ios::binary);
if (!ofmStream.is_open()) {
cerr << "Error: Failed to open '" << ofmFilename << "'" << endl;
exit(1);
}
/* The inference completed and has ok status */
for (auto &ofmBuffer : inference->getOfmBuffers()) {
cout << "OFM size: " << ofmBuffer->size() << endl;
if (print) {
cout << "OFM data: " << *ofmBuffer << endl;
}
ofmStream.write(ofmBuffer->data(), ofmBuffer->size());
}
ofmStream.flush();
/* Read out PMU counters if configured */
if (std::count(enabledCounters.begin(), enabledCounters.end(), 0) <
Inference::getMaxPmuEventCounters()) {
const std::vector<uint64_t> pmus = inference->getPmuCounters();
cout << "PMUs : [";
for (auto p : pmus) {
cout << " " << p;
}
cout << " ]" << endl;
}
if (enableCycleCounter)
cout << "Cycle counter: " << inference->getCycleCounter() << endl;
}
ofmIndex++;
}
} catch (Exception &e) {
cerr << "Error: " << e.what() << endl;
return 1;
}
return 0;
}