blob: 47fc62f7f0a359315aee8e169b46e3e71176d6b2 [file] [log] [blame]
Teresa Charlin9bab4962019-09-06 12:28:35 +01001//
2// Copyright © 2019 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "SocketProfilingConnection.hpp"
7
8#include <fcntl.h>
9#include <sys/socket.h>
10#include <sys/un.h>
11#include <cerrno>
12#include <string>
13
14namespace armnn
15{
16namespace profiling
17{
18
19SocketProfilingConnection::SocketProfilingConnection()
20{
21 memset(m_Socket, 0, sizeof(m_Socket));
22 // Note: we're using Linux specific SOCK_CLOEXEC flag.
23 m_Socket[0].fd = socket(PF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0);
24 if (m_Socket[0].fd == -1)
25 {
Matteo Martincigh24e8f922019-09-19 11:57:46 +010026 throw armnn::RuntimeException(std::string("Socket construction failed: ") + strerror(errno));
Teresa Charlin9bab4962019-09-06 12:28:35 +010027 }
28
29 // Connect to the named unix domain socket.
30 struct sockaddr_un server{};
31 memset(&server, 0, sizeof(sockaddr_un));
32 // As m_GatorNamespace begins with a null character we need to ignore that when getting its length.
33 memcpy(server.sun_path, m_GatorNamespace, strlen(m_GatorNamespace + 1) + 1);
34 server.sun_family = AF_UNIX;
35 if (0 != connect(m_Socket[0].fd, reinterpret_cast<const sockaddr*>(&server), sizeof(sockaddr_un)))
36 {
37 close(m_Socket[0].fd);
Matteo Martincigh24e8f922019-09-19 11:57:46 +010038 throw armnn::RuntimeException(std::string("Cannot connect to stream socket: ") + strerror(errno));
Teresa Charlin9bab4962019-09-06 12:28:35 +010039 }
40
41 // Our socket will only be interested in polling reads.
42 m_Socket[0].events = POLLIN;
43
44 // Make the socket non blocking.
45 const int currentFlags = fcntl(m_Socket[0].fd, F_GETFL);
46 if (0 != fcntl(m_Socket[0].fd, F_SETFL, currentFlags | O_NONBLOCK))
47 {
48 close(m_Socket[0].fd);
Matteo Martincigh24e8f922019-09-19 11:57:46 +010049 throw armnn::RuntimeException(std::string("Failed to set socket as non blocking: ") + strerror(errno));
Teresa Charlin9bab4962019-09-06 12:28:35 +010050 }
51}
52
53bool SocketProfilingConnection::IsOpen()
54{
Matteo Martincigh24e8f922019-09-19 11:57:46 +010055 return m_Socket[0].fd > 0;
Teresa Charlin9bab4962019-09-06 12:28:35 +010056}
57
58void SocketProfilingConnection::Close()
59{
Matteo Martincigh24e8f922019-09-19 11:57:46 +010060 if (close(m_Socket[0].fd) != 0)
FinnWilliamsArma0c78712019-09-16 12:06:47 +010061 {
Matteo Martincigh24e8f922019-09-19 11:57:46 +010062 throw armnn::RuntimeException(std::string("Cannot close stream socket: ") + strerror(errno));
FinnWilliamsArma0c78712019-09-16 12:06:47 +010063 }
Matteo Martincigh24e8f922019-09-19 11:57:46 +010064
65 memset(m_Socket, 0, sizeof(m_Socket));
Teresa Charlin9bab4962019-09-06 12:28:35 +010066}
67
Matteo Martincigh24e8f922019-09-19 11:57:46 +010068bool SocketProfilingConnection::WritePacket(const unsigned char* buffer, uint32_t length)
Teresa Charlin9bab4962019-09-06 12:28:35 +010069{
Matteo Martincigh24e8f922019-09-19 11:57:46 +010070 if (buffer == nullptr || length == 0)
FinnWilliamsArma0c78712019-09-16 12:06:47 +010071 {
72 return false;
73 }
Matteo Martincigh24e8f922019-09-19 11:57:46 +010074
75 return write(m_Socket[0].fd, buffer, length) != -1;
Teresa Charlin9bab4962019-09-06 12:28:35 +010076}
77
78Packet SocketProfilingConnection::ReadPacket(uint32_t timeout)
79{
Matteo Martincigh24e8f922019-09-19 11:57:46 +010080 // Poll for data on the socket or until timeout occurs
FinnWilliamsArma0c78712019-09-16 12:06:47 +010081 int pollResult = poll(m_Socket, 1, static_cast<int>(timeout));
Matteo Martincigh24e8f922019-09-19 11:57:46 +010082
83 switch (pollResult)
FinnWilliamsArma0c78712019-09-16 12:06:47 +010084 {
Matteo Martincigh24e8f922019-09-19 11:57:46 +010085 case -1: // Error
86 throw armnn::RuntimeException(std::string("Read failure from socket: ") + strerror(errno));
87
88 case 0: // Timeout
89 throw armnn::RuntimeException("Timeout while reading from socket");
90
91 default: // Normal poll return but it could still contain an error signal
92
93 // Check if the socket reported an error
FinnWilliamsArma0c78712019-09-16 12:06:47 +010094 if (m_Socket[0].revents & (POLLNVAL | POLLERR | POLLHUP))
95 {
Matteo Martincigh24e8f922019-09-19 11:57:46 +010096 throw armnn::Exception(std::string("Socket 0 reported an error: ") + strerror(errno));
FinnWilliamsArma0c78712019-09-16 12:06:47 +010097 }
FinnWilliamsArma0c78712019-09-16 12:06:47 +010098
Matteo Martincigh24e8f922019-09-19 11:57:46 +010099 // Check if there is data to read
100 if (!(m_Socket[0].revents & (POLLIN)))
FinnWilliamsArma0c78712019-09-16 12:06:47 +0100101 {
Matteo Martincigh24e8f922019-09-19 11:57:46 +0100102 // No data to read from the socket. Silently ignore and continue
103 return Packet();
FinnWilliamsArma0c78712019-09-16 12:06:47 +0100104 }
Matteo Martincigh24e8f922019-09-19 11:57:46 +0100105
106 // There is data to read, read the header first
107 char header[8] = {};
108 if (8 != recv(m_Socket[0].fd, &header, sizeof(header), 0))
109 {
110 // What do we do here if there's not a valid 8 byte header to read?
111 throw armnn::RuntimeException("The received packet did not contains a valid MIPE header");
112 }
113
114 // stream_metadata_identifier is the first 4 bytes
115 uint32_t metadataIdentifier = 0;
116 std::memcpy(&metadataIdentifier, header, sizeof(metadataIdentifier));
117
118 // data_length is the next 4 bytes
119 uint32_t dataLength = 0;
120 std::memcpy(&dataLength, header + 4u, sizeof(dataLength));
121
122 std::unique_ptr<char[]> packetData;
123 if (dataLength > 0)
124 {
125 packetData = std::make_unique<char[]>(dataLength);
126 }
127
128 if (dataLength != recv(m_Socket[0].fd, packetData.get(), dataLength, 0))
129 {
130 // What do we do here if we can't read in a full packet?
131 throw armnn::RuntimeException("Invalid MIPE packet");
132 }
133
134 return Packet(metadataIdentifier, dataLength, packetData);
FinnWilliamsArma0c78712019-09-16 12:06:47 +0100135 }
Teresa Charlin9bab4962019-09-06 12:28:35 +0100136}
137
138} // namespace profiling
139} // namespace armnn