blob: c3a310c5cb85f1b1598f93bd00514151542a48ec [file] [log] [blame]
Kristofer Jonsson116a6352020-08-20 17:25:23 +02001/*
2 * Copyright (c) 2020 Arm Limited. All rights reserved.
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 *
6 * Licensed under the Apache License, Version 2.0 (the License); you may
7 * not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19#pragma once
20
#include <uapi/ethosu.h>

#include <algorithm>
#include <iterator>
#include <memory>
#include <string>
#include <vector>
Kristofer Jonsson116a6352020-08-20 17:25:23 +020027
28namespace EthosU
29{
30
/**
 * Exception type of the EthosU driver library, derived from std::exception.
 *
 * Holds a message string in the `msg` member. Presumably it is initialized from
 * the constructor argument and returned by what(); see the out-of-line
 * definitions.
 */
class Exception :
    public std::exception
{
public:
    /**
     * @param msg C string describing the error.
     */
    Exception(const char *msg);

    // `noexcept` replaces the deprecated `throw()` dynamic exception
    // specification (removed in C++20). Both forms are non-throwing and
    // compatible, so existing out-of-line definitions still match.
    ~Exception() noexcept override;

    /** @return The error message; never throws. */
    const char *what() const noexcept override;

private:
    std::string msg;
};
42
/**
 * Handle to an Ethos-U device node, identified by the file descriptor `fd`.
 *
 * Copying is disabled. A copy would hold the same descriptor as the original,
 * so two destructors would release it twice.
 * NOTE(review): assumes ~Device closes `fd`; confirm against the out-of-line
 * definition.
 */
class Device
{
public:
    /**
     * @param device Path of the device node; defaults to "/dev/ethosu0".
     *               Presumably opened by the constructor (confirm).
     */
    Device(const char *device = "/dev/ethosu0");
    virtual ~Device();

    // Owns a raw descriptor: forbid copies to avoid double release.
    Device(const Device &) = delete;
    Device &operator=(const Device &) = delete;

    /**
     * Issue an ioctl request on this device.
     * @param cmd  Request code.
     * @param data Optional request argument; nullptr by default.
     * @return Integer result; presumably the ::ioctl return value (confirm).
     */
    int ioctl(unsigned long cmd, void *data = nullptr);

private:
    int fd;
};
54
55class Buffer
56{
57public:
58 Buffer(Device &device, const size_t capacity);
59 virtual ~Buffer();
60
61 size_t capacity() const;
62 void clear();
63 char *data();
64 void resize(size_t size, size_t offset = 0);
65 size_t offset() const;
66 size_t size() const;
67
68 int getFd() const;
69
70private:
71 int fd;
72 char *dataPtr;
73 const size_t dataCapacity;
74};
75
76class Network
77{
78public:
79 Network(Device &device, std::shared_ptr<Buffer> &buffer);
80 virtual ~Network();
81
82 int ioctl(unsigned long cmd, void *data = nullptr);
83 std::shared_ptr<Buffer> getBuffer();
Kristofer Jonssonb74492c2020-09-10 13:26:01 +020084 const std::vector<size_t> &getIfmDims() const;
85 size_t getIfmSize() const;
86 const std::vector<size_t> &getOfmDims() const;
87 size_t getOfmSize() const;
Kristofer Jonsson116a6352020-08-20 17:25:23 +020088
89private:
90 int fd;
91 std::shared_ptr<Buffer> buffer;
Kristofer Jonssonb74492c2020-09-10 13:26:01 +020092 std::vector<size_t> ifmDims;
93 std::vector<size_t> ofmDims;
Kristofer Jonsson116a6352020-08-20 17:25:23 +020094};
95
96class Inference
97{
98public:
Kristofer Jonssonb74492c2020-09-10 13:26:01 +020099 template <typename T>
100 Inference(std::shared_ptr<Network> &network, const T &ifmBegin, const T &ifmEnd, const T &ofmBegin, const T &ofmEnd) :
101 network(network)
102 {
103 std::copy(ifmBegin, ifmEnd, std::back_inserter(ifmBuffers));
104 std::copy(ofmBegin, ofmEnd, std::back_inserter(ofmBuffers));
105 create();
106 }
Kristofer Jonsson116a6352020-08-20 17:25:23 +0200107 virtual ~Inference();
108
109 void wait(int timeoutSec = -1);
110 bool failed();
111 int getFd();
112 std::shared_ptr<Network> getNetwork();
Kristofer Jonssonb74492c2020-09-10 13:26:01 +0200113 std::vector<std::shared_ptr<Buffer>> &getIfmBuffers();
114 std::vector<std::shared_ptr<Buffer>> &getOfmBuffers();
Kristofer Jonsson116a6352020-08-20 17:25:23 +0200115
116private:
Kristofer Jonssonb74492c2020-09-10 13:26:01 +0200117 void create();
118
Kristofer Jonsson116a6352020-08-20 17:25:23 +0200119 int fd;
120 std::shared_ptr<Network> network;
Kristofer Jonssonb74492c2020-09-10 13:26:01 +0200121 std::vector<std::shared_ptr<Buffer>> ifmBuffers;
122 std::vector<std::shared_ptr<Buffer>> ofmBuffers;
Kristofer Jonsson116a6352020-08-20 17:25:23 +0200123};
124
125}