Add thin abstraction layer for network sockets

This makes SocketProfilingConnection and GatordMock work on Windows as
well as Linux

Change-Id: I4b10c079b653a1c3f61eb20694e5b5f8a6f5fdfb
Signed-off-by: Robert Hughes <robert.hughes@arm.com>
diff --git a/Android.mk b/Android.mk
index bfaee44..8f348d9 100644
--- a/Android.mk
+++ b/Android.mk
@@ -117,6 +117,7 @@
         src/armnnUtils/Permute.cpp \
         src/armnnUtils/TensorUtils.cpp \
         src/armnnUtils/VerificationHelpers.cpp \
+        src/armnnUtils/NetworkSockets.cpp \
         src/armnn/layers/AbsLayer.cpp \
         src/armnn/layers/ActivationLayer.cpp \
         src/armnn/layers/AdditionLayer.cpp \
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4d54137..e39c2b8 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -65,6 +65,8 @@
     src/armnnUtils/QuantizeHelper.hpp
     src/armnnUtils/TensorIOUtils.hpp
     src/armnnUtils/TensorUtils.cpp
+    src/armnnUtils/NetworkSockets.hpp
+    src/armnnUtils/NetworkSockets.cpp
     )
 
 add_library_ex(armnnUtils STATIC ${armnnUtils_sources})
@@ -533,6 +535,9 @@
 target_link_libraries(armnn armnnUtils)
 
 target_link_libraries(armnn ${CMAKE_DL_LIBS})
+if ("${CMAKE_SYSTEM_NAME}" STREQUAL Windows) 
+    target_link_libraries(armnn Ws2_32.lib)
+endif()
 
 install(TARGETS armnn
         LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
@@ -963,8 +968,10 @@
     include_directories(${Boost_INCLUDE_DIRS} tests/profiling/timelineDecoder)
 
     add_library_ex(gatordMockService STATIC ${gatord_mock_sources})
+    target_include_directories(gatordMockService PRIVATE src/armnnUtils)
 
     add_executable_ex(GatordMock tests/profiling/gatordmock/GatordMockMain.cpp)
+    target_include_directories(GatordMock PRIVATE src/armnnUtils)
 
     target_link_libraries(GatordMock
         armnn
diff --git a/src/armnnUtils/NetworkSockets.cpp b/src/armnnUtils/NetworkSockets.cpp
new file mode 100644
index 0000000..cc28a90
--- /dev/null
+++ b/src/armnnUtils/NetworkSockets.cpp
@@ -0,0 +1,99 @@
+//
+// Copyright © 2020 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "NetworkSockets.hpp"
+
+#if defined(__unix__)
+#include <unistd.h>
+#include <fcntl.h>
+#endif
+
+namespace armnnUtils
+{
+namespace Sockets
+{
+
+bool Initialize()
+{
+#if defined(__unix__)
+    return true;
+#elif defined(_MSC_VER)
+    WSADATA wsaData;
+    return WSAStartup(MAKEWORD(2, 2), &wsaData) == 0;
+#endif
+}
+
+int Close(Socket s)
+{
+#if defined(__unix__)
+    return close(s);
+#elif defined(_MSC_VER)
+    return closesocket(s);
+#endif
+}
+
+
+bool SetNonBlocking(Socket s)
+{
+#if defined(__unix__)
+    const int currentFlags = fcntl(s, F_GETFL);
+    return fcntl(s, F_SETFL, currentFlags | O_NONBLOCK) == 0;
+#elif defined(_MSC_VER)
+    u_long mode = 1;
+    return ioctlsocket(s, FIONBIO, &mode) == 0;
+#endif
+}
+
+
+long Write(Socket s, const void* buf, size_t len)
+{
+#if defined(__unix__)
+    return write(s, buf, len);
+#elif defined(_MSC_VER)
+    return send(s, static_cast<const char*>(buf), len, 0);
+#endif
+}
+
+
+long Read(Socket s, void* buf, size_t len)
+{
+#if defined(__unix__)
+    return read(s, buf, len);
+#elif defined(_MSC_VER)
+    return recv(s, static_cast<char*>(buf), len, 0);
+#endif
+}
+
+int Ioctl(Socket s, unsigned long cmd, void* arg)
+{
+#if defined(__unix__)
+    return ioctl(s, cmd, arg);
+#elif defined(_MSC_VER)
+    return ioctlsocket(s, cmd, static_cast<u_long*>(arg));
+#endif
+}
+
+
+int Poll(PollFd* fds, size_t numFds, int timeout)
+{
+#if defined(__unix__)
+    return poll(fds, numFds, timeout);
+#elif defined(_MSC_VER)
+    return WSAPoll(fds, numFds, timeout);
+#endif
+}
+
+
+armnnUtils::Sockets::Socket Accept(Socket s, sockaddr* addr, unsigned int* addrlen, int flags)
+{
+#if defined(__unix__)
+    return accept4(s, addr, addrlen, flags);
+#elif defined(_MSC_VER)
+    return accept(s, addr, reinterpret_cast<int*>(addrlen));
+#endif
+}
+
+}
+}
diff --git a/src/armnnUtils/NetworkSockets.hpp b/src/armnnUtils/NetworkSockets.hpp
new file mode 100644
index 0000000..9e47707
--- /dev/null
+++ b/src/armnnUtils/NetworkSockets.hpp
@@ -0,0 +1,59 @@
+//
+// Copyright © 2020 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+// This file (along with its corresponding .cpp) defines a very thin platform abstraction layer for the use of
+// networking sockets. Thankfully the underlying APIs on Windows and Linux are very similar so not much conversion
+// is needed (typically just forwarding the parameters to a differently named function).
+// Some of the APIs are in fact completely identical and so no forwarding function is needed.
+
+#pragma once
+
+#if defined(__unix__)
+#include <poll.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/un.h>
+#elif defined(_MSC_VER)
+#include <winsock2.h>
+#include <afunix.h>
+#endif
+
+namespace armnnUtils
+{
+namespace Sockets
+{
+
+#if defined(__unix__)
+
+using Socket = int;
+using PollFd = pollfd;
+
+#elif defined(_MSC_VER)
+
+using Socket = SOCKET;
+using PollFd = WSAPOLLFD;
+#define SOCK_CLOEXEC 0
+
+#endif
+
+/// Performs any required one-time setup.
+bool Initialize();
+
+int Close(Socket s);
+
+bool SetNonBlocking(Socket s);
+
+long Write(Socket s, const void* buf, size_t len);
+
+long Read(Socket s, void* buf, size_t len);
+
+int Ioctl(Socket s, unsigned long cmd, void* arg);
+
+int Poll(PollFd* fds, size_t numFds, int timeout);
+
+Socket Accept(Socket s, sockaddr* addr, unsigned int* addrlen, int flags);
+
+}
+}
diff --git a/src/profiling/SocketProfilingConnection.cpp b/src/profiling/SocketProfilingConnection.cpp
index c78c182..4bbbc29 100644
--- a/src/profiling/SocketProfilingConnection.cpp
+++ b/src/profiling/SocketProfilingConnection.cpp
@@ -7,11 +7,10 @@
 
 #include <cerrno>
 #include <fcntl.h>
-#include <sys/ioctl.h>
-#include <sys/socket.h>
-#include <sys/un.h>
 #include <string>
 
+using namespace armnnUtils;
+
 namespace armnn
 {
 namespace profiling
@@ -19,6 +18,7 @@
 
 SocketProfilingConnection::SocketProfilingConnection()
 {
+    Sockets::Initialize();
     memset(m_Socket, 0, sizeof(m_Socket));
     // Note: we're using Linux specific SOCK_CLOEXEC flag.
     m_Socket[0].fd = socket(PF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0);
@@ -28,7 +28,7 @@
     }
 
     // Connect to the named unix domain socket.
-    struct sockaddr_un server{};
+    sockaddr_un server{};
     memset(&server, 0, sizeof(sockaddr_un));
     // As m_GatorNamespace begins with a null character we need to ignore that when getting its length.
     memcpy(server.sun_path, m_GatorNamespace, strlen(m_GatorNamespace + 1) + 1);
@@ -43,8 +43,7 @@
     m_Socket[0].events = POLLIN;
 
     // Make the socket non blocking.
-    const int currentFlags = fcntl(m_Socket[0].fd, F_GETFL);
-    if (0 != fcntl(m_Socket[0].fd, F_SETFL, currentFlags | O_NONBLOCK))
+    if (!Sockets::SetNonBlocking(m_Socket[0].fd))
     {
         Close();
         throw armnn::RuntimeException(std::string("Failed to set socket as non blocking: ") + strerror(errno));
@@ -58,7 +57,7 @@
 
 void SocketProfilingConnection::Close()
 {
-    if (close(m_Socket[0].fd) != 0)
+    if (Sockets::Close(m_Socket[0].fd) != 0)
     {
         throw armnn::RuntimeException(std::string("Cannot close stream socket: ") + strerror(errno));
     }
@@ -73,14 +72,14 @@
         return false;
     }
 
-    return write(m_Socket[0].fd, buffer, length) != -1;
+    return Sockets::Write(m_Socket[0].fd, buffer, length) != -1;
 }
 
 Packet SocketProfilingConnection::ReadPacket(uint32_t timeout)
 {
     // Is there currently at least a header worth of data waiting to be read?
     int bytes_available = 0;
-    ioctl(m_Socket[0].fd, FIONREAD, &bytes_available);
+    Sockets::Ioctl(m_Socket[0].fd, FIONREAD, &bytes_available);
     if (bytes_available >= 8)
     {
         // Yes there is. Read it:
@@ -88,7 +87,7 @@
     }
 
     // Poll for data on the socket or until timeout occurs
-    int pollResult = poll(m_Socket, 1, static_cast<int>(timeout));
+    int pollResult = Sockets::Poll(&m_Socket[0], 1, static_cast<int>(timeout));
 
     switch (pollResult)
     {
@@ -136,7 +135,7 @@
 Packet SocketProfilingConnection::ReceivePacket()
 {
     char header[8] = {};
-    ssize_t receiveResult = recv(m_Socket[0].fd, &header, sizeof(header), 0);
+    long receiveResult = Sockets::Read(m_Socket[0].fd, &header, sizeof(header));
     // We expect 8 as the result here. 0 means EOF, socket is closed. -1 means there been some other kind of error.
     switch( receiveResult )
     {
@@ -168,10 +167,10 @@
     if (dataLength > 0)
     {
         packetData = std::make_unique<unsigned char[]>(dataLength);
-        ssize_t receivedLength = recv(m_Socket[0].fd, packetData.get(), dataLength, 0);
+        long receivedLength = Sockets::Read(m_Socket[0].fd, packetData.get(), dataLength);
         if (receivedLength < 0)
         {
-            throw armnn::RuntimeException(std::string("Error occured on recv: ") + strerror(errno));
+            throw armnn::RuntimeException(std::string("Error occurred on recv: ") + strerror(errno));
         }
         if (dataLength != static_cast<uint32_t>(receivedLength))
         {
diff --git a/src/profiling/SocketProfilingConnection.hpp b/src/profiling/SocketProfilingConnection.hpp
index 5fb02bb..05c7130 100644
--- a/src/profiling/SocketProfilingConnection.hpp
+++ b/src/profiling/SocketProfilingConnection.hpp
@@ -5,8 +5,8 @@
 
 #include "IProfilingConnection.hpp"
 
-#include <poll.h>
 #include <Runtime.hpp>
+#include <NetworkSockets.hpp>
 
 #pragma once
 
@@ -31,7 +31,7 @@
 
     // To indicate we want to use an abstract UDS ensure the first character of the address is 0.
     const char* m_GatorNamespace = "\0gatord_namespace";
-    struct pollfd m_Socket[1]{};
+    armnnUtils::Sockets::PollFd m_Socket[1]{};
 };
 
 } // namespace profiling
diff --git a/tests/profiling/gatordmock/CommandFileParser.cpp b/tests/profiling/gatordmock/CommandFileParser.cpp
index 4a8a19b..7c746f1 100644
--- a/tests/profiling/gatordmock/CommandFileParser.cpp
+++ b/tests/profiling/gatordmock/CommandFileParser.cpp
@@ -54,7 +54,7 @@
             // 500000       polling period in micro seconds
             // 1 2 5 10     counter list
 
-            uint period = static_cast<uint>(std::stoul(tokens[1]));
+            uint32_t period = static_cast<uint32_t>(std::stoul(tokens[1]));
 
             std::vector<uint16_t> counters;
 
@@ -73,7 +73,7 @@
             // WAIT         command
             // 11000000     timeout period in micro seconds
 
-            uint timeout = static_cast<uint>(std::stoul(tokens[1]));
+            uint32_t timeout = static_cast<uint32_t>(std::stoul(tokens[1]));
 
             mockService.WaitCommand(timeout);
         }
diff --git a/tests/profiling/gatordmock/GatordMockService.cpp b/tests/profiling/gatordmock/GatordMockService.cpp
index 529ef06..c521196 100644
--- a/tests/profiling/gatordmock/GatordMockService.cpp
+++ b/tests/profiling/gatordmock/GatordMockService.cpp
@@ -8,17 +8,15 @@
 #include <CommandHandlerRegistry.hpp>
 #include <PacketVersionResolver.hpp>
 #include <ProfilingUtils.hpp>
+#include <NetworkSockets.hpp>
 
 #include <cerrno>
 #include <fcntl.h>
 #include <iomanip>
 #include <iostream>
-#include <poll.h>
 #include <string>
-#include <sys/ioctl.h>
-#include <sys/socket.h>
-#include <sys/un.h>
-#include <unistd.h>
+
+using namespace armnnUtils;
 
 namespace armnn
 {
@@ -28,6 +26,7 @@
 
 bool GatordMockService::OpenListeningSocket(std::string udsNamespace)
 {
+    Sockets::Initialize();
     m_ListeningSocket = socket(PF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0);
     if (-1 == m_ListeningSocket)
     {
@@ -56,9 +55,9 @@
     return true;
 }
 
-int GatordMockService::BlockForOneClient()
+Sockets::Socket GatordMockService::BlockForOneClient()
 {
-    m_ClientConnection = accept4(m_ListeningSocket, nullptr, nullptr, SOCK_CLOEXEC);
+    m_ClientConnection = Sockets::Accept(m_ListeningSocket, nullptr, nullptr, SOCK_CLOEXEC);
     if (-1 == m_ClientConnection)
     {
         std::cerr << ": Failure when waiting for a client connection: " << strerror(errno) << std::endl;
@@ -112,13 +111,14 @@
     // Remember we already read the pipe magic 4 bytes.
     uint32_t metaDataLength = ToUint32(&header[4], m_Endianness) - 4;
     // Read the entire packet.
-    uint8_t packetData[metaDataLength];
-    if (metaDataLength != boost::numeric_cast<uint32_t>(read(m_ClientConnection, &packetData, metaDataLength)))
+    std::vector<uint8_t> packetData(metaDataLength);
+    if (metaDataLength !=
+        boost::numeric_cast<uint32_t>(Sockets::Read(m_ClientConnection, packetData.data(), metaDataLength)))
     {
         std::cerr << ": Protocol read error. Data length mismatch." << std::endl;
         return false;
     }
-    EchoPacket(PacketDirection::ReceivedData, packetData, metaDataLength);
+    EchoPacket(PacketDirection::ReceivedData, packetData.data(), metaDataLength);
     m_StreamMetaDataVersion    = ToUint32(&packetData[0], m_Endianness);
     m_StreamMetaDataMaxDataLen = ToUint32(&packetData[4], m_Endianness);
     m_StreamMetaDataPid        = ToUint32(&packetData[8], m_Endianness);
@@ -153,10 +153,9 @@
         std::cout << "Launching receiving thread." << std::endl;
     }
     // At this point we want to make the socket non blocking.
-    const int currentFlags = fcntl(m_ClientConnection, F_GETFL);
-    if (0 != fcntl(m_ClientConnection, F_SETFL, currentFlags | O_NONBLOCK))
+    if (!Sockets::SetNonBlocking(m_ClientConnection))
     {
-        close(m_ClientConnection);
+        Sockets::Close(m_ClientConnection);
         std::cerr << "Failed to set socket as non blocking: " << strerror(errno) << std::endl;
         return false;
     }
@@ -212,13 +211,13 @@
     // should deal with it.
 }
 
-void GatordMockService::WaitCommand(uint timeout)
+void GatordMockService::WaitCommand(uint32_t timeout)
 {
     // Wait for a maximum of timeout microseconds or if the receive thread has closed.
     // There is a certain level of rounding involved in this timing.
-    uint iterations = timeout / 1000;
+    uint32_t iterations = timeout / 1000;
     std::cout << std::dec << "Wait command with timeout of " << timeout << " iterations =  " << iterations << std::endl;
-    uint count = 0;
+    uint32_t count = 0;
     while ((this->ReceiveThreadRunning() && (count < iterations)))
     {
         std::this_thread::sleep_for(std::chrono::microseconds(1000));
@@ -261,7 +260,7 @@
 {
     // Is there currently more than a headers worth of data waiting to be read?
     int bytes_available;
-    ioctl(m_ClientConnection, FIONREAD, &bytes_available);
+    Sockets::Ioctl(m_ClientConnection, FIONREAD, &bytes_available);
     if (bytes_available > 8)
     {
         // Yes there is. Read it:
@@ -272,7 +271,7 @@
         // No there's not. Poll for more data.
         struct pollfd pollingFd[1]{};
         pollingFd[0].fd = m_ClientConnection;
-        int pollResult  = poll(pollingFd, 1, static_cast<int>(timeoutMs));
+        int pollResult  = Sockets::Poll(pollingFd, 1, static_cast<int>(timeoutMs));
 
         switch (pollResult)
         {
@@ -362,16 +361,16 @@
     header[0] = packetFamily << 26 | packetId << 16;
     header[1] = dataLength;
     // Add the header to the packet.
-    uint8_t packet[8 + dataLength];
-    InsertU32(header[0], packet, m_Endianness);
-    InsertU32(header[1], packet + 4, m_Endianness);
+    std::vector<uint8_t> packet(8 + dataLength);
+    InsertU32(header[0], packet.data(), m_Endianness);
+    InsertU32(header[1], packet.data() + 4, m_Endianness);
     // And the rest of the data if there is any.
     if (dataLength > 0)
     {
-        memcpy((packet + 8), data, dataLength);
+        memcpy((packet.data() + 8), data, dataLength);
     }
-    EchoPacket(PacketDirection::Sending, packet, sizeof(packet));
-    if (-1 == write(m_ClientConnection, packet, sizeof(packet)))
+    EchoPacket(PacketDirection::Sending, packet.data(), packet.size());
+    if (-1 == Sockets::Write(m_ClientConnection, packet.data(), packet.size()))
     {
         std::cerr << ": Failure when writing to client socket: " << strerror(errno) << std::endl;
         return false;
@@ -396,10 +395,10 @@
 bool GatordMockService::ReadFromSocket(uint8_t* packetData, uint32_t expectedLength)
 {
     // This is a blocking read until either expectedLength has been received or an error is detected.
-    ssize_t totalBytesRead = 0;
+    long totalBytesRead = 0;
     while (boost::numeric_cast<uint32_t>(totalBytesRead) < expectedLength)
     {
-        ssize_t bytesRead = recv(m_ClientConnection, packetData, expectedLength, 0);
+        long bytesRead = Sockets::Read(m_ClientConnection, packetData, expectedLength);
         if (bytesRead < 0)
         {
             std::cerr << ": Failure when reading from client socket: " << strerror(errno) << std::endl;
diff --git a/tests/profiling/gatordmock/GatordMockService.hpp b/tests/profiling/gatordmock/GatordMockService.hpp
index c3afc33..f91e902 100644
--- a/tests/profiling/gatordmock/GatordMockService.hpp
+++ b/tests/profiling/gatordmock/GatordMockService.hpp
@@ -7,6 +7,7 @@
 
 #include <CommandHandlerRegistry.hpp>
 #include <Packet.hpp>
+#include <NetworkSockets.hpp>
 
 #include <atomic>
 #include <string>
@@ -49,8 +50,8 @@
     ~GatordMockService()
     {
         // We have set SOCK_CLOEXEC on these sockets but we'll close them to be good citizens.
-        close(m_ClientConnection);
-        close(m_ListeningSocket);
+        armnnUtils::Sockets::Close(m_ClientConnection);
+        armnnUtils::Sockets::Close(m_ListeningSocket);
     }
 
     /// Establish the Unix domain socket and set it to listen for connections.
@@ -60,7 +61,7 @@
 
     /// Block waiting to accept one client to connect to the UDS.
     /// @return the file descriptor of the client connection.
-    int BlockForOneClient();
+    armnnUtils::Sockets::Socket BlockForOneClient();
 
     /// Once the connection is open wait to receive the stream meta data packet from the client. Reading this
     /// packet differs from others as we need to determine endianness.
@@ -147,8 +148,8 @@
     armnn::profiling::CommandHandlerRegistry& m_HandlerRegistry;
 
     bool m_EchoPackets;
-    int m_ListeningSocket;
-    int m_ClientConnection;
+    armnnUtils::Sockets::Socket m_ListeningSocket;
+    armnnUtils::Sockets::Socket m_ClientConnection;
     std::thread m_ListeningThread;
     std::atomic<bool> m_CloseReceivingThread;
 };
diff --git a/tests/profiling/timelineDecoder/TimelineCaptureCommandHandler.cpp b/tests/profiling/timelineDecoder/TimelineCaptureCommandHandler.cpp
index bdceca6..78b1300 100644
--- a/tests/profiling/timelineDecoder/TimelineCaptureCommandHandler.cpp
+++ b/tests/profiling/timelineDecoder/TimelineCaptureCommandHandler.cpp
@@ -122,7 +122,7 @@
     event.m_TimeStamp = profiling::ReadUint64(data, offset);
     offset += uint64_t_size;
 
-    event.m_ThreadId = new u_int8_t[threadId_size];
+    event.m_ThreadId = new uint8_t[threadId_size];
     profiling::ReadBytes(data, offset, threadId_size, event.m_ThreadId);
     offset += threadId_size;