IVGCVSW-3937 Improve the Connection Acknowledged Handler

 * The Connection Acknowledged Handler should report an error
   is it's called while in a wrong state
 * Stopping the threads in the ProfilingService before having
   to start them again
 * Updated the unit tests to check the changes
 * Removed unnecessary Packet.cpp file
 * Fixed memory leak

Signed-off-by: Matteo Martincigh <matteo.martincigh@arm.com>
Change-Id: I8c4d33b4d97994df86fe6c9f8c659f880ec64c16
diff --git a/Android.mk b/Android.mk
index fcbab68..108e011 100644
--- a/Android.mk
+++ b/Android.mk
@@ -179,7 +179,6 @@
         src/profiling/CounterDirectory.cpp \
         src/profiling/Holder.cpp \
         src/profiling/PacketBuffer.cpp \
-        src/profiling/Packet.cpp \
         src/profiling/PacketVersionResolver.cpp \
         src/profiling/PeriodicCounterCapture.cpp \
         src/profiling/PeriodicCounterSelectionCommandHandler.cpp \
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 3b27d05..a4c8fc9 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -451,7 +451,6 @@
     src/profiling/IPeriodicCounterCapture.hpp
     src/profiling/IProfilingConnection.hpp
     src/profiling/IProfilingConnectionFactory.hpp
-    src/profiling/Packet.cpp
     src/profiling/Packet.hpp
     src/profiling/PacketBuffer.cpp
     src/profiling/PacketBuffer.hpp
@@ -599,7 +598,9 @@
         src/profiling/test/BufferTests.cpp
         src/profiling/test/ProfilingConnectionDumpToFileDecoratorTests.cpp
         src/profiling/test/ProfilingTests.cpp
+        src/profiling/test/ProfilingTests.hpp
         src/profiling/test/SendCounterPacketTests.cpp
+        src/profiling/test/SendCounterPacketTests.hpp
         src/profiling/test/TimelinePacketTests.cpp
         )
 
diff --git a/src/profiling/CommandHandler.cpp b/src/profiling/CommandHandler.cpp
index 86fa257..cc68dcf 100644
--- a/src/profiling/CommandHandler.cpp
+++ b/src/profiling/CommandHandler.cpp
@@ -5,6 +5,8 @@
 
 #include "CommandHandler.hpp"
 
+#include <boost/log/trivial.hpp>
+
 namespace armnn
 {
 
@@ -39,7 +41,14 @@
     {
         try
         {
-            Packet packet = profilingConnection.ReadPacket(m_Timeout);
+            Packet packet = profilingConnection.ReadPacket(m_Timeout.load());
+
+            if (packet.IsEmpty())
+            {
+                // Nothing to do, continue
+                continue;
+            }
+
             Version version = m_PacketVersionResolver.ResolvePacketVersion(packet.GetPacketId());
 
             CommandHandlerFunctor* commandHandlerFunctor =
@@ -49,19 +58,15 @@
         }
         catch (const armnn::TimeoutException&)
         {
-            if (m_StopAfterTimeout)
+            if (m_StopAfterTimeout.load())
             {
-                m_KeepRunning.store(false, std::memory_order_relaxed);
+                m_KeepRunning.store(false);
             }
         }
         catch (const Exception& e)
         {
-            // Log the error
-            BOOST_LOG_TRIVIAL(warning) << "An error has occurred when handling a command: "
-                                       << e.what();
-
-            // Might want to differentiate the errors more
-            m_KeepRunning.store(false);
+            // Log the error and continue
+            BOOST_LOG_TRIVIAL(warning) << "An error has occurred when handling a command: " << e.what() << std::endl;
         }
     }
     while (m_KeepRunning.load());
diff --git a/src/profiling/ConnectionAcknowledgedCommandHandler.cpp b/src/profiling/ConnectionAcknowledgedCommandHandler.cpp
index f90b601..9d2d1a2 100644
--- a/src/profiling/ConnectionAcknowledgedCommandHandler.cpp
+++ b/src/profiling/ConnectionAcknowledgedCommandHandler.cpp
@@ -7,6 +7,8 @@
 
 #include <armnn/Exceptions.hpp>
 
+#include <boost/format.hpp>
+
 namespace armnn
 {
 
@@ -15,15 +17,34 @@
 
 void ConnectionAcknowledgedCommandHandler::operator()(const Packet& packet)
 {
-    if (!(packet.GetPacketFamily() == 0u && packet.GetPacketId() == 1u))
+    ProfilingState currentState = m_StateMachine.GetCurrentState();
+    switch (currentState)
     {
-        throw armnn::InvalidArgumentException(std::string("Expected Packet family = 0, id = 1 but received family = ")
-                                              + std::to_string(packet.GetPacketFamily())
-                                              + " id = " + std::to_string(packet.GetPacketId()));
-    }
+    case ProfilingState::Uninitialised:
+    case ProfilingState::NotConnected:
+        throw RuntimeException(boost::str(boost::format("Connection Acknowledged Handler invoked while in an "
+                                                        "wrong state: %1%")
+                                          % GetProfilingStateName(currentState)));
+    case ProfilingState::WaitingForAck:
+        // Process the packet
+        if (!(packet.GetPacketFamily() == 0u && packet.GetPacketId() == 1u))
+        {
+            throw armnn::InvalidArgumentException(boost::str(boost::format("Expected Packet family = 0, id = 1 but "
+                                                                           "received family = %1%, id = %2%")
+                                                  % packet.GetPacketFamily()
+                                                  % packet.GetPacketId()));
+        }
 
-    // Once a Connection Acknowledged packet has been received, move to the Active state immediately
-    m_StateMachine.TransitionToState(ProfilingState::Active);
+        // Once a Connection Acknowledged packet has been received, move to the Active state immediately
+        m_StateMachine.TransitionToState(ProfilingState::Active);
+
+        break;
+    case ProfilingState::Active:
+        return; // NOP
+    default:
+        throw RuntimeException(boost::str(boost::format("Unknown profiling service state: %1%")
+                                          % static_cast<int>(currentState)));
+    }
 }
 
 } // namespace profiling
diff --git a/src/profiling/Packet.cpp b/src/profiling/Packet.cpp
deleted file mode 100644
index 4cfa42b..0000000
--- a/src/profiling/Packet.cpp
+++ /dev/null
@@ -1,51 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "Packet.hpp"
-
-namespace armnn
-{
-
-namespace profiling
-{
-
-std::uint32_t Packet::GetHeader() const
-{
-    return m_Header;
-}
-
-std::uint32_t Packet::GetPacketFamily() const
-{
-    return m_PacketFamily;
-}
-
-std::uint32_t Packet::GetPacketId() const
-{
-    return m_PacketId;
-}
-
-std::uint32_t Packet::GetLength() const
-{
-    return m_Length;
-}
-
-const char* const Packet::GetData() const
-{
-    return m_Data.get();
-}
-
-std::uint32_t Packet::GetPacketClass() const
-{
-    return (m_PacketId >> 3);
-}
-
-std::uint32_t Packet::GetPacketType() const
-{
-    return (m_PacketId & 7);
-}
-
-} // namespace profiling
-
-} // namespace armnn
diff --git a/src/profiling/Packet.hpp b/src/profiling/Packet.hpp
index 2aae14b..fae368b 100644
--- a/src/profiling/Packet.hpp
+++ b/src/profiling/Packet.hpp
@@ -2,11 +2,12 @@
 // Copyright © 2017 Arm Ltd. All rights reserved.
 // SPDX-License-Identifier: MIT
 //
+
 #pragma once
 
 #include <armnn/Exceptions.hpp>
 
-#include <boost/log/trivial.hpp>
+#include <memory>
 
 namespace armnn
 {
@@ -46,26 +47,32 @@
         }
     }
 
-    Packet(Packet&& other) :
-           m_Header(other.m_Header),
-           m_PacketFamily(other.m_PacketFamily),
-           m_PacketId(other.m_PacketId),
-           m_Length(other.m_Length),
-           m_Data(std::move(other.m_Data))
-    {}
+    Packet(Packet&& other)
+        : m_Header(other.m_Header)
+        , m_PacketFamily(other.m_PacketFamily)
+        , m_PacketId(other.m_PacketId)
+        , m_Length(other.m_Length)
+        , m_Data(std::move(other.m_Data))
+    {
+        other.m_Header = 0;
+        other.m_PacketFamily = 0;
+        other.m_PacketId = 0;
+        other.m_Length = 0;
+    }
+
+    ~Packet() = default;
 
     Packet(const Packet& other) = delete;
     Packet& operator=(const Packet&) = delete;
     Packet& operator=(Packet&&) = default;
 
-    uint32_t GetHeader() const;
-    uint32_t GetPacketFamily() const;
-    uint32_t GetPacketId() const;
-    uint32_t GetLength() const;
-    const char* const GetData() const;
-
-    uint32_t GetPacketClass() const;
-    uint32_t GetPacketType() const;
+    uint32_t GetHeader() const        { return m_Header;        }
+    uint32_t GetPacketFamily() const  { return m_PacketFamily;  }
+    uint32_t GetPacketId() const      { return m_PacketId;      }
+    uint32_t GetPacketClass() const   { return m_PacketId >> 3; }
+    uint32_t GetPacketType() const    { return m_PacketId & 7;  }
+    uint32_t GetLength() const        { return m_Length;        }
+    const char* const GetData() const { return m_Data.get();    }
 
     bool IsEmpty() { return m_Header == 0 && m_Length == 0; }
 
diff --git a/src/profiling/ProfilingService.cpp b/src/profiling/ProfilingService.cpp
index 19cf9cb..693f833 100644
--- a/src/profiling/ProfilingService.cpp
+++ b/src/profiling/ProfilingService.cpp
@@ -47,7 +47,11 @@
         m_StateMachine.TransitionToState(ProfilingState::NotConnected);
         break;
     case ProfilingState::NotConnected:
-        BOOST_ASSERT(m_ProfilingConnectionFactory);
+        // Stop the command thread (if running)
+        m_CommandHandler.Stop();
+
+        // Stop the send thread (if running)
+        m_SendCounterPacket.Stop(false);
 
         // Reset any existing profiling connection
         m_ProfilingConnection.reset();
@@ -55,13 +59,13 @@
         try
         {
             // Setup the profiling connection
-            //m_ProfilingConnection = m_ProfilingConnectionFactory.GetProfilingConnection(m_Options);
+            BOOST_ASSERT(m_ProfilingConnectionFactory);
             m_ProfilingConnection = m_ProfilingConnectionFactory->GetProfilingConnection(m_Options);
         }
         catch (const Exception& e)
         {
             BOOST_LOG_TRIVIAL(warning) << "An error has occurred when creating the profiling connection: "
-                                       << e.what();
+                                       << e.what() << std::endl;
         }
 
         // Move to the next state
@@ -229,13 +233,23 @@
 void ProfilingService::Reset()
 {
     // Reset the profiling service
-    m_CounterDirectory.Clear();
-    m_ProfilingConnection.reset();
-    m_StateMachine.Reset();
-    m_CounterIndex.clear();
-    m_CounterValues.clear();
+
+    // The order in which we reset/stop the components is not trivial!
+
+    // First stop the threads (Command Handler first)...
     m_CommandHandler.Stop();
     m_SendCounterPacket.Stop(false);
+
+    // ...then destroy the profiling connection...
+    m_ProfilingConnection.reset();
+
+    // ...then delete all the counter data and configuration...
+    m_CounterIndex.clear();
+    m_CounterValues.clear();
+    m_CounterDirectory.Clear();
+
+    // ...finally reset the profiling state machine
+    m_StateMachine.Reset();
 }
 
 } // namespace profiling
diff --git a/src/profiling/ProfilingService.hpp b/src/profiling/ProfilingService.hpp
index 50a938e..edeb6bd 100644
--- a/src/profiling/ProfilingService.hpp
+++ b/src/profiling/ProfilingService.hpp
@@ -109,7 +109,7 @@
     }
     ~ProfilingService() = default;
 
-    // Protected method for testing
+    // Protected methods for testing
     void SwapProfilingConnectionFactory(ProfilingService& instance,
                                         IProfilingConnectionFactory* other,
                                         IProfilingConnectionFactory*& backup)
@@ -120,6 +120,10 @@
         backup = instance.m_ProfilingConnectionFactory.release();
         instance.m_ProfilingConnectionFactory.reset(other);
     }
+    IProfilingConnection* GetProfilingConnection(ProfilingService& instance)
+    {
+        return instance.m_ProfilingConnection.get();
+    }
 };
 
 } // namespace profiling
diff --git a/src/profiling/RequestCounterDirectoryCommandHandler.cpp b/src/profiling/RequestCounterDirectoryCommandHandler.cpp
index f186add..0fdcf10 100644
--- a/src/profiling/RequestCounterDirectoryCommandHandler.cpp
+++ b/src/profiling/RequestCounterDirectoryCommandHandler.cpp
@@ -5,6 +5,8 @@
 
 #include "RequestCounterDirectoryCommandHandler.hpp"
 
+#include <boost/assert.hpp>
+
 namespace armnn
 {
 
@@ -21,4 +23,4 @@
 
 } // namespace profiling
 
-} // namespace armnn
\ No newline at end of file
+} // namespace armnn
diff --git a/src/profiling/SendCounterPacket.cpp b/src/profiling/SendCounterPacket.cpp
index b9f2b18..e48da3e 100644
--- a/src/profiling/SendCounterPacket.cpp
+++ b/src/profiling/SendCounterPacket.cpp
@@ -945,7 +945,7 @@
     // Exception handling lock scope - Begin
     {
         // Lock the mutex to handle any exception coming from the send thread
-        std::unique_lock<std::mutex> lock(m_WaitMutex);
+        std::lock_guard<std::mutex> lock(m_WaitMutex);
 
         // Check if there's an exception to rethrow
         if (m_SendThreadException)
diff --git a/src/profiling/test/ProfilingTests.cpp b/src/profiling/test/ProfilingTests.cpp
index de92fb9..80d99dd 100644
--- a/src/profiling/test/ProfilingTests.cpp
+++ b/src/profiling/test/ProfilingTests.cpp
@@ -3,11 +3,10 @@
 // SPDX-License-Identifier: MIT
 //
 
-#include "SendCounterPacketTests.hpp"
+#include "ProfilingTests.hpp"
 
 #include <CommandHandler.hpp>
 #include <CommandHandlerKey.hpp>
-#include <CommandHandlerFunctor.hpp>
 #include <CommandHandlerRegistry.hpp>
 #include <ConnectionAcknowledgedCommandHandler.hpp>
 #include <CounterDirectory.hpp>
@@ -19,7 +18,6 @@
 #include <PeriodicCounterCapture.hpp>
 #include <PeriodicCounterSelectionCommandHandler.hpp>
 #include <ProfilingStateMachine.hpp>
-#include <ProfilingService.hpp>
 #include <ProfilingUtils.hpp>
 #include <RequestCounterDirectoryCommandHandler.hpp>
 #include <Runtime.hpp>
@@ -27,21 +25,16 @@
 
 #include <armnn/Conversion.hpp>
 
-#include <Logging.hpp>
 #include <armnn/Utils.hpp>
 
 #include <boost/algorithm/string.hpp>
 #include <boost/numeric/conversion/cast.hpp>
-#include <boost/test/unit_test.hpp>
 
 #include <cstdint>
 #include <cstring>
-#include <iostream>
 #include <limits>
 #include <map>
 #include <random>
-#include <thread>
-#include <chrono>
 
 using namespace armnn::profiling;
 
@@ -94,59 +87,6 @@
     BOOST_CHECK(vect == expectedVect);
 }
 
-class TestProfilingConnectionBase :public IProfilingConnection
-{
-public:
-    TestProfilingConnectionBase() = default;
-    ~TestProfilingConnectionBase() = default;
-
-    bool IsOpen() const override { return true; }
-
-    void Close() override {}
-
-    bool WritePacket(const unsigned char* buffer, uint32_t length) override { return false; }
-
-    Packet ReadPacket(uint32_t timeout) override
-    {
-        std::this_thread::sleep_for(std::chrono::milliseconds(timeout));
-        std::unique_ptr<char[]> packetData;
-
-        // Return connection acknowledged packet
-        return { 65536, 0, packetData };
-    }
-};
-
-class TestProfilingConnectionTimeoutError : public TestProfilingConnectionBase
-{
-public:
-    Packet ReadPacket(uint32_t timeout) {
-        if (readRequests < 3)
-        {
-            readRequests++;
-            throw armnn::TimeoutException("Simulate a timeout");
-        }
-        std::this_thread::sleep_for(std::chrono::milliseconds(timeout));
-        std::unique_ptr<char[]> packetData;
-
-        // Return connection acknowledged packet after three timeouts
-        return { 65536, 0, packetData };
-    }
-
-private:
-    int readRequests = 0;
-};
-
-class TestProfilingConnectionArmnnError :public TestProfilingConnectionBase
-{
-public:
-
-    Packet ReadPacket(uint32_t timeout)
-    {
-        std::this_thread::sleep_for(std::chrono::milliseconds(timeout));
-        throw armnn::Exception(": Simulate a non timeout error");
-    }
-};
-
 BOOST_AUTO_TEST_CASE(CheckCommandHandler)
 {
     PacketVersionResolver packetVersionResolver;
@@ -180,7 +120,7 @@
     profilingStateMachine.TransitionToState(ProfilingState::NotConnected);
     profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck);
     // commandHandler1 should give up after one timeout
-    CommandHandler commandHandler1(1,
+    CommandHandler commandHandler1(10,
                                    true,
                                    commandHandlerRegistry,
                                    packetVersionResolver);
@@ -204,32 +144,24 @@
             break;
         }
 
-        std::this_thread::sleep_for(std::chrono::milliseconds(5));
+        std::this_thread::sleep_for(std::chrono::milliseconds(100));
     }
 
     commandHandler1.Stop();
 
     BOOST_CHECK(profilingStateMachine.GetCurrentState() == ProfilingState::Active);
 
-    CommandHandler commandHandler2(1,
+    CommandHandler commandHandler2(100,
                                    false,
                                    commandHandlerRegistry,
                                    packetVersionResolver);
 
     commandHandler2.Start(testProfilingConnectionArmnnError);
 
-    for (int i = 0; i < 100; i++)
-    {
-        if (!commandHandler2.IsRunning())
-        {
-            // commandHandler2 should stop once it encounters a non timing error
-            return;
-        }
+    // commandHandler2 should not stop once it encounters a non timing error
+    std::this_thread::sleep_for(std::chrono::milliseconds(500));
 
-        std::this_thread::sleep_for(std::chrono::milliseconds(5));
-    }
-
-    BOOST_ERROR("commandHandler2 has failed to stop");
+    BOOST_CHECK(commandHandler2.IsRunning());
     commandHandler2.Stop();
 }
 
@@ -300,33 +232,6 @@
     BOOST_CHECK(packetTest4.GetPacketClass() == 5);
 }
 
-// Create Derived Classes
-class TestFunctorA : public CommandHandlerFunctor
-{
-public:
-    using CommandHandlerFunctor::CommandHandlerFunctor;
-
-    int GetCount() { return m_Count; }
-
-    void operator()(const Packet& packet) override
-    {
-        m_Count++;
-    }
-
-private:
-    int m_Count = 0;
-};
-
-class TestFunctorB : public TestFunctorA
-{
-    using TestFunctorA::TestFunctorA;
-};
-
-class TestFunctorC : public TestFunctorA
-{
-    using TestFunctorA::TestFunctorA;
-};
-
 BOOST_AUTO_TEST_CASE(CheckCommandHandlerFunctor)
 {
     // Hard code the version as it will be the same during a single profiling session
@@ -455,6 +360,7 @@
         BOOST_TEST(resolvedVersion == expectedVersion);
     }
 }
+
 void ProfilingCurrentStateThreadImpl(ProfilingStateMachine& states)
 {
     ProfilingState newState = ProfilingState::NotConnected;
@@ -664,32 +570,6 @@
     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
 }
 
-struct LogLevelSwapper
-{
-public:
-    LogLevelSwapper(armnn::LogSeverity severity)
-    {
-        // Set the new log level
-        armnn::ConfigureLogging(true, true, severity);
-    }
-    ~LogLevelSwapper()
-    {
-        // The default log level for unit tests is "Fatal"
-        armnn::ConfigureLogging(true, true, armnn::LogSeverity::Fatal);
-    }
-};
-
-struct CoutRedirect
-{
-public:
-    CoutRedirect(std::streambuf* newStreamBuffer)
-        : old(std::cout.rdbuf(newStreamBuffer)) {}
-    ~CoutRedirect() { std::cout.rdbuf(old); }
-
-private:
-    std::streambuf* old;
-};
-
 BOOST_AUTO_TEST_CASE(CheckProfilingServiceEnabled)
 {
     // Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output
@@ -705,7 +585,7 @@
 
     // Redirect the output to a local stream so that we can parse the warning message
     std::stringstream ss;
-    CoutRedirect coutRedirect(ss.rdbuf());
+    StreamRedirector streamRedirector(std::cout, ss.rdbuf());
     profilingService.Update();
     BOOST_CHECK(boost::contains(ss.str(), "Cannot connect to stream socket: Connection refused"));
 }
@@ -729,7 +609,7 @@
 
     // Redirect the output to a local stream so that we can parse the warning message
     std::stringstream ss;
-    CoutRedirect coutRedirect(ss.rdbuf());
+    StreamRedirector streamRedirector(std::cout, ss.rdbuf());
     profilingService.Update();
     BOOST_CHECK(boost::contains(ss.str(), "Cannot connect to stream socket: Connection refused"));
 }
@@ -1949,16 +1829,18 @@
     profilingState.TransitionToState(ProfilingState::WaitingForAck);
     BOOST_CHECK(profilingState.GetCurrentState() == ProfilingState::WaitingForAck);
     // command handler received packet on ProfilingState::WaitingForAck
-    commandHandler(packetA);
+    BOOST_CHECK_NO_THROW(commandHandler(packetA));
     BOOST_CHECK(profilingState.GetCurrentState() == ProfilingState::Active);
 
     // command handler received packet on ProfilingState::Active
-    commandHandler(packetA);
+    BOOST_CHECK_NO_THROW(commandHandler(packetA));
     BOOST_CHECK(profilingState.GetCurrentState() == ProfilingState::Active);
 
     // command handler received different packet
     const uint32_t differentPacketId = 0x40000;
     Packet packetB(differentPacketId, dataLength1, uniqueData1);
+    profilingState.TransitionToState(ProfilingState::NotConnected);
+    profilingState.TransitionToState(ProfilingState::WaitingForAck);
     ConnectionAcknowledgedCommandHandler differentCommandHandler(differentPacketId, version, profilingState);
     BOOST_CHECK_THROW(differentCommandHandler(packetB), armnn::Exception);
 }
@@ -2333,62 +2215,17 @@
     BOOST_TEST(categoryRecordOffset ==  44);
 }
 
-class MockProfilingConnectionFactory : public IProfilingConnectionFactory
-{
-public:
-    MockProfilingConnectionFactory()
-        : m_MockProfilingConnection(new MockProfilingConnection())
-    {}
-
-    IProfilingConnectionPtr GetProfilingConnection(const ExternalProfilingOptions& options) const override
-    {
-        return std::unique_ptr<MockProfilingConnection>(m_MockProfilingConnection);
-    }
-
-    MockProfilingConnection* GetMockProfilingConnection() { return m_MockProfilingConnection; }
-
-private:
-    MockProfilingConnection* m_MockProfilingConnection;
-};
-
-class SwapProfilingConnectionFactoryHelper : public ProfilingService
-{
-public:
-    SwapProfilingConnectionFactoryHelper()
-        : ProfilingService()
-        , m_MockProfilingConnectionFactory(new MockProfilingConnectionFactory())
-        , m_BackupProfilingConnectionFactory(nullptr)
-    {
-        SwapProfilingConnectionFactory(ProfilingService::Instance(),
-                                       m_MockProfilingConnectionFactory.get(),
-                                       m_BackupProfilingConnectionFactory);
-    }
-    ~SwapProfilingConnectionFactoryHelper()
-    {
-        IProfilingConnectionFactory* temp = nullptr;
-        SwapProfilingConnectionFactory(ProfilingService::Instance(),
-                                       m_BackupProfilingConnectionFactory,
-                                       temp);
-    }
-
-    IProfilingConnectionFactory* GetMockProfilingConnectionFactory() { return m_MockProfilingConnectionFactory.get(); }
-
-private:
-    IProfilingConnectionFactoryPtr m_MockProfilingConnectionFactory;
-    IProfilingConnectionFactory* m_BackupProfilingConnectionFactory;
-};
-
 BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadConnectionAcknowledgedPacket)
 {
     // Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output
     LogLevelSwapper logLevelSwapper(armnn::LogSeverity::Warning);
 
+    // Swap the profiling connection factory in the profiling service instance with our mock one
     SwapProfilingConnectionFactoryHelper helper;
-    MockProfilingConnectionFactory* mockProfilingConnectionFactory =
-            boost::polymorphic_downcast<MockProfilingConnectionFactory*>(helper.GetMockProfilingConnectionFactory());
-    BOOST_CHECK(mockProfilingConnectionFactory);
-    MockProfilingConnection* mockProfilingConnection = mockProfilingConnectionFactory->GetMockProfilingConnection();
-    BOOST_CHECK(mockProfilingConnection);
+
+    // Redirect the standard output to a local stream so that we can parse the warning message
+    std::stringstream ss;
+    StreamRedirector streamRedirector(std::cout, ss.rdbuf());
 
     // Calculate the size of a Stream Metadata packet
     std::string processName = GetProcessName().substr(0, 60);
@@ -2408,15 +2245,15 @@
     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
     profilingService.Update();
 
-    // Redirect the output to a local stream so that we can parse the warning message
-    std::stringstream ss;
-    CoutRedirect coutRedirect(ss.rdbuf());
-
     // Wait for a bit to make sure that we get the packet
     std::this_thread::sleep_for(std::chrono::milliseconds(100));
 
+    // Get the mock profiling connection
+    MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
+    BOOST_CHECK(mockProfilingConnection);
+
     // Check that the mock profiling connection contains one Stream Metadata packet
-    const std::vector<uint32_t>& writtenData = mockProfilingConnection->GetWrittenData();
+    const std::vector<uint32_t> writtenData = mockProfilingConnection->GetWrittenData();
     BOOST_TEST(writtenData.size() == 1);
     BOOST_TEST(writtenData[0] == streamMetadataPacketsize);
 
@@ -2433,7 +2270,7 @@
     uint32_t header = ((packetFamily & 0x0000003F) << 26) |
                       ((packetId     & 0x000003FF) << 16);
 
-    // Connection Acknowledged Packet
+    // Create the Connection Acknowledged Packet
     Packet connectionAcknowledgedPacket(header);
 
     // Write the packet to the mock profiling connection
@@ -2441,23 +2278,23 @@
 
     // Wait for a bit (must at least be the delay value of the mock profiling connection) to make sure that
     // the Connection Acknowledged packet gets processed by the profiling service
-    std::this_thread::sleep_for(std::chrono::seconds(1));
+    std::this_thread::sleep_for(std::chrono::seconds(2));
 
     // Check that the expected error has occurred and logged to the standard output
     BOOST_CHECK(boost::contains(ss.str(), "Functor with requested PacketId=37 and Version=4194304 does not exist"));
 
     // The Connection Acknowledged Command Handler should not have updated the profiling state
     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
+
+    // Reset the profiling service to stop any running thread
+    options.m_EnableProfiling = false;
+    profilingService.ResetExternalProfilingOptions(options, true);
 }
 
 BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodConnectionAcknowledgedPacket)
 {
+    // Swap the profiling connection factory in the profiling service instance with our mock one
     SwapProfilingConnectionFactoryHelper helper;
-    MockProfilingConnectionFactory* mockProfilingConnectionFactory =
-            boost::polymorphic_downcast<MockProfilingConnectionFactory*>(helper.GetMockProfilingConnectionFactory());
-    BOOST_CHECK(mockProfilingConnectionFactory);
-    MockProfilingConnection* mockProfilingConnection = mockProfilingConnectionFactory->GetMockProfilingConnection();
-    BOOST_CHECK(mockProfilingConnection);
 
     // Calculate the size of a Stream Metadata packet
     std::string processName = GetProcessName().substr(0, 60);
@@ -2480,8 +2317,12 @@
     // Wait for a bit to make sure that we get the packet
     std::this_thread::sleep_for(std::chrono::milliseconds(100));
 
+    // Get the mock profiling connection
+    MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
+    BOOST_CHECK(mockProfilingConnection);
+
     // Check that the mock profiling connection contains one Stream Metadata packet
-    const std::vector<uint32_t>& writtenData = mockProfilingConnection->GetWrittenData();
+    const std::vector<uint32_t> writtenData = mockProfilingConnection->GetWrittenData();
     BOOST_TEST(writtenData.size() == 1);
     BOOST_TEST(writtenData[0] == streamMetadataPacketsize);
 
@@ -2498,7 +2339,7 @@
     uint32_t header = ((packetFamily & 0x0000003F) << 26) |
                       ((packetId     & 0x000003FF) << 16);
 
-    // Connection Acknowledged Packet
+    // Create the Connection Acknowledged Packet
     Packet connectionAcknowledgedPacket(header);
 
     // Write the packet to the mock profiling connection
@@ -2506,10 +2347,14 @@
 
     // Wait for a bit (must at least be the delay value of the mock profiling connection) to make sure that
     // the Connection Acknowledged packet gets processed by the profiling service
-    std::this_thread::sleep_for(std::chrono::seconds(1));
+    std::this_thread::sleep_for(std::chrono::seconds(2));
 
     // The Connection Acknowledged Command Handler should have updated the profiling state accordingly
     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
+
+    // Reset the profiling service to stop any running thread
+    options.m_EnableProfiling = false;
+    profilingService.ResetExternalProfilingOptions(options, true);
 }
 
 BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/profiling/test/ProfilingTests.hpp b/src/profiling/test/ProfilingTests.hpp
new file mode 100644
index 0000000..3e6cf63
--- /dev/null
+++ b/src/profiling/test/ProfilingTests.hpp
@@ -0,0 +1,200 @@
+//
+// Copyright © 2019 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "SendCounterPacketTests.hpp"
+
+#include <CommandHandlerFunctor.hpp>
+#include <IProfilingConnection.hpp>
+#include <IProfilingConnectionFactory.hpp>
+#include <Logging.hpp>
+#include <ProfilingService.hpp>
+
+#include <boost/test/unit_test.hpp>
+
+#include <chrono>
+#include <iostream>
+#include <thread>
+
+namespace armnn
+{
+
+namespace profiling
+{
+
+struct LogLevelSwapper
+{
+public:
+    LogLevelSwapper(armnn::LogSeverity severity)
+    {
+        // Set the new log level
+        armnn::ConfigureLogging(true, true, severity);
+    }
+    ~LogLevelSwapper()
+    {
+        // The default log level for unit tests is "Fatal"
+        armnn::ConfigureLogging(true, true, armnn::LogSeverity::Fatal);
+    }
+};
+
+struct CoutRedirect
+{
+public:
+    CoutRedirect(std::streambuf* newStreamBuffer)
+        : m_Old(std::cout.rdbuf(newStreamBuffer)) {}
+    ~CoutRedirect() { std::cout.rdbuf(m_Old); }
+
+private:
+    std::streambuf* m_Old;
+};
+
+struct StreamRedirector
+{
+public:
+    StreamRedirector(std::ostream& stream, std::streambuf* newStreamBuffer)
+        : m_Stream(stream)
+        , m_BackupBuffer(m_Stream.rdbuf(newStreamBuffer))
+    {}
+    ~StreamRedirector() { m_Stream.rdbuf(m_BackupBuffer); }
+
+private:
+    std::ostream& m_Stream;
+    std::streambuf* m_BackupBuffer;
+};
+
+class TestProfilingConnectionBase : public IProfilingConnection
+{
+public:
+    TestProfilingConnectionBase() = default;
+    ~TestProfilingConnectionBase() = default;
+
+    bool IsOpen() const override { return true; }
+
+    void Close() override {}
+
+    bool WritePacket(const unsigned char* buffer, uint32_t length) override { return false; }
+
+    Packet ReadPacket(uint32_t timeout) override
+    {
+        std::this_thread::sleep_for(std::chrono::milliseconds(timeout));
+
+        // Return connection acknowledged packet
+        std::unique_ptr<char[]> packetData;
+        return Packet(65536, 0, packetData);
+    }
+};
+
+class TestProfilingConnectionTimeoutError : public TestProfilingConnectionBase
+{
+public:
+    TestProfilingConnectionTimeoutError()
+        : m_ReadRequests(0)
+    {}
+
+    Packet ReadPacket(uint32_t timeout) override
+    {
+        std::this_thread::sleep_for(std::chrono::milliseconds(timeout));
+
+        if (m_ReadRequests < 3)
+        {
+            m_ReadRequests++;
+            throw armnn::TimeoutException("Simulate a timeout error\n");
+        }
+
+        // Return connection acknowledged packet after three timeouts
+        std::unique_ptr<char[]> packetData;
+        return Packet(65536, 0, packetData);
+    }
+
+private:
+    int m_ReadRequests;
+};
+
+class TestProfilingConnectionArmnnError : public TestProfilingConnectionBase
+{
+public:
+    Packet ReadPacket(uint32_t timeout) override
+    {
+        std::this_thread::sleep_for(std::chrono::milliseconds(timeout));
+
+        throw armnn::Exception("Simulate a non-timeout error");
+    }
+};
+
+class TestFunctorA : public CommandHandlerFunctor
+{
+public:
+    using CommandHandlerFunctor::CommandHandlerFunctor;
+
+    int GetCount() { return m_Count; }
+
+    void operator()(const Packet& packet) override
+    {
+        m_Count++;
+    }
+
+private:
+    int m_Count = 0;
+};
+
+class TestFunctorB : public TestFunctorA
+{
+    using TestFunctorA::TestFunctorA;
+};
+
+class TestFunctorC : public TestFunctorA
+{
+    using TestFunctorA::TestFunctorA;
+};
+
+class MockProfilingConnectionFactory : public IProfilingConnectionFactory
+{
+public:
+    IProfilingConnectionPtr GetProfilingConnection(const ExternalProfilingOptions& options) const override
+    {
+        return std::make_unique<MockProfilingConnection>();
+    }
+};
+
+class SwapProfilingConnectionFactoryHelper : public ProfilingService
+{
+public:
+    using MockProfilingConnectionFactoryPtr = std::unique_ptr<MockProfilingConnectionFactory>;
+
+    SwapProfilingConnectionFactoryHelper()
+        : ProfilingService()
+        , m_MockProfilingConnectionFactory(new MockProfilingConnectionFactory())
+        , m_BackupProfilingConnectionFactory(nullptr)
+    {
+        BOOST_CHECK(m_MockProfilingConnectionFactory);
+        SwapProfilingConnectionFactory(ProfilingService::Instance(),
+                                       m_MockProfilingConnectionFactory.get(),
+                                       m_BackupProfilingConnectionFactory);
+        BOOST_CHECK(m_BackupProfilingConnectionFactory);
+    }
+    ~SwapProfilingConnectionFactoryHelper()
+    {
+        BOOST_CHECK(m_BackupProfilingConnectionFactory);
+        IProfilingConnectionFactory* temp = nullptr;
+        SwapProfilingConnectionFactory(ProfilingService::Instance(),
+                                       m_BackupProfilingConnectionFactory,
+                                       temp);
+    }
+
+    MockProfilingConnection* GetMockProfilingConnection()
+    {
+        IProfilingConnection* profilingConnection = GetProfilingConnection(ProfilingService::Instance());
+        return boost::polymorphic_downcast<MockProfilingConnection*>(profilingConnection);
+    }
+
+private:
+    MockProfilingConnectionFactoryPtr m_MockProfilingConnectionFactory;
+    IProfilingConnectionFactory* m_BackupProfilingConnectionFactory;
+};
+
+} // namespace profiling
+
+} // namespace armnn
diff --git a/src/profiling/test/SendCounterPacketTests.cpp b/src/profiling/test/SendCounterPacketTests.cpp
index 1216420..00dad38 100644
--- a/src/profiling/test/SendCounterPacketTests.cpp
+++ b/src/profiling/test/SendCounterPacketTests.cpp
@@ -2322,7 +2322,7 @@
     BOOST_TEST(reservedBuffer.get());
 
     // Check that data was actually written to the profiling connection in any order
-    const std::vector<uint32_t>& writtenData = mockProfilingConnection.GetWrittenData();
+    const std::vector<uint32_t> writtenData = mockProfilingConnection.GetWrittenData();
     BOOST_TEST(writtenData.size() == 3);
     bool foundStreamMetaDataPacket =
         std::find(writtenData.begin(), writtenData.end(), streamMetadataPacketsize) != writtenData.end();
@@ -2391,7 +2391,7 @@
     BOOST_CHECK_NO_THROW(sendCounterPacket.Stop());
 
     // Check that the buffer contains one Stream Metadata packet
-    const std::vector<uint32_t>& writtenData = mockProfilingConnection.GetWrittenData();
+    const std::vector<uint32_t> writtenData = mockProfilingConnection.GetWrittenData();
     BOOST_TEST(writtenData.size() == 1);
     BOOST_TEST(writtenData[0] == streamMetadataPacketsize);
 }
@@ -2420,7 +2420,7 @@
     BOOST_TEST((profilingStateMachine.GetCurrentState() == ProfilingState::WaitingForAck));
 
     // Check that the buffer contains one Stream Metadata packet
-    const std::vector<uint32_t>& writtenData = mockProfilingConnection.GetWrittenData();
+    const std::vector<uint32_t> writtenData = mockProfilingConnection.GetWrittenData();
     BOOST_TEST(writtenData.size() == 1);
     BOOST_TEST(writtenData[0] == streamMetadataPacketsize);
 
diff --git a/src/profiling/test/SendCounterPacketTests.hpp b/src/profiling/test/SendCounterPacketTests.hpp
index 48bab02..871ca74 100644
--- a/src/profiling/test/SendCounterPacketTests.hpp
+++ b/src/profiling/test/SendCounterPacketTests.hpp
@@ -12,6 +12,7 @@
 #include <armnn/Optional.hpp>
 #include <armnn/Conversion.hpp>
 
+#include <boost/assert.hpp>
 #include <boost/numeric/conversion/cast.hpp>
 
 namespace armnn
@@ -19,6 +20,7 @@
 
 namespace profiling
 {
+
 class MockProfilingConnection : public IProfilingConnection
 {
 public:
@@ -28,9 +30,19 @@
         , m_Packet()
     {}
 
-    bool IsOpen() const override { return m_IsOpen; }
+    bool IsOpen() const override
+    {
+        std::lock_guard<std::mutex> lock(m_Mutex);
 
-    void Close() override { m_IsOpen = false; }
+        return m_IsOpen;
+    }
+
+    void Close() override
+    {
+        std::lock_guard<std::mutex> lock(m_Mutex);
+
+        m_IsOpen = false;
+    }
 
     bool WritePacket(const unsigned char* buffer, uint32_t length) override
     {
@@ -39,11 +51,15 @@
             return false;
         }
 
+        std::lock_guard<std::mutex> lock(m_Mutex);
+
         m_WrittenData.push_back(length);
         return true;
     }
     bool WritePacket(Packet&& packet)
     {
+        std::lock_guard<std::mutex> lock(m_Mutex);
+
         m_Packet = std::move(packet);
         return true;
     }
@@ -51,19 +67,32 @@
     Packet ReadPacket(uint32_t timeout) override
     {
         // Simulate a delay in the reading process
-        std::this_thread::sleep_for(std::chrono::milliseconds(500));
+        std::this_thread::sleep_for(std::chrono::milliseconds(timeout));
+
+        std::lock_guard<std::mutex> lock(m_Mutex);
 
         return std::move(m_Packet);
     }
 
-    const std::vector<uint32_t>& GetWrittenData() const { return m_WrittenData; }
+    const std::vector<uint32_t> GetWrittenData() const
+    {
+        std::lock_guard<std::mutex> lock(m_Mutex);
 
-    void Clear() { m_WrittenData.clear(); }
+        return m_WrittenData;
+    }
+
+    void Clear()
+    {
+        std::lock_guard<std::mutex> lock(m_Mutex);
+
+        m_WrittenData.clear();
+    }
 
 private:
     bool m_IsOpen;
     std::vector<uint32_t> m_WrittenData;
     Packet m_Packet;
+    mutable std::mutex m_Mutex;
 };
 
 class MockPacketBuffer : public IPacketBuffer
@@ -162,7 +191,7 @@
 
     IPacketBufferPtr Reserve(unsigned int requestedSize, unsigned int& reservedSize) override
     {
-        std::unique_lock<std::mutex> lock(m_Mutex);
+        std::lock_guard<std::mutex> lock(m_Mutex);
 
         reservedSize = 0;
         if (requestedSize > m_MaxBufferSize)
@@ -176,7 +205,7 @@
 
     void Commit(IPacketBufferPtr& packetBuffer, unsigned int size) override
     {
-        std::unique_lock<std::mutex> lock(m_Mutex);
+        std::lock_guard<std::mutex> lock(m_Mutex);
 
         packetBuffer->Commit(size);
         m_BufferList.push_back(std::move(packetBuffer));
@@ -185,14 +214,14 @@
 
     void Release(IPacketBufferPtr& packetBuffer) override
     {
-        std::unique_lock<std::mutex> lock(m_Mutex);
+        std::lock_guard<std::mutex> lock(m_Mutex);
 
         packetBuffer->Release();
     }
 
     IPacketBufferPtr GetReadableBuffer() override
     {
-        std::unique_lock<std::mutex> lock(m_Mutex);
+        std::lock_guard<std::mutex> lock(m_Mutex);
 
         if (m_BufferList.empty())
         {
@@ -206,7 +235,7 @@
 
     void MarkRead(IPacketBufferPtr& packetBuffer) override
     {
-        std::unique_lock<std::mutex> lock(m_Mutex);
+        std::lock_guard<std::mutex> lock(m_Mutex);
 
         m_ReadSize += packetBuffer->GetSize();
         packetBuffer->MarkRead();