IVGCVSW-3937 Refactor the command thread

 * Integrated the Join method into Stop
 * Updated the unit tests accordingly
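 * Replaced StopAfterTimeout with SetStopAfterTimeout and added SetTimeout
 * Made the timeout and stop-after-timeout settings atomic so they can be
   changed while the thread is running
 * Moved the constructor to the header and added a destructor that stops
   the thread
 * General code refactoring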

Signed-off-by: Matteo Martincigh <matteo.martincigh@arm.com>
Change-Id: If8537e77b3d3ff2b780f58a07df01191a91d83d2
diff --git a/src/profiling/CommandThread.cpp b/src/profiling/CommandThread.cpp
index bd4aa96..320e4bc 100644
--- a/src/profiling/CommandThread.cpp
+++ b/src/profiling/CommandThread.cpp
@@ -12,69 +12,26 @@
 namespace profiling
 {
 
-CommandThread::CommandThread(uint32_t timeout,
-                             bool stopAfterTimeout,
-                             CommandHandlerRegistry& commandHandlerRegistry,
-                             PacketVersionResolver& packetVersionResolver,
-                             IProfilingConnection& socketProfilingConnection)
-    : m_Timeout(timeout)
-    , m_StopAfterTimeout(stopAfterTimeout)
-    , m_IsRunning(false)
-    , m_CommandHandlerRegistry(commandHandlerRegistry)
-    , m_PacketVersionResolver(packetVersionResolver)
-    , m_SocketProfilingConnection(socketProfilingConnection)
-{};
-
-void CommandThread::WaitForPacket()
-{
-    do {
-        try
-        {
-            Packet packet = m_SocketProfilingConnection.ReadPacket(m_Timeout);
-            Version version = m_PacketVersionResolver.ResolvePacketVersion(packet.GetPacketId());
-
-            CommandHandlerFunctor* commandHandlerFunctor =
-                m_CommandHandlerRegistry.GetFunctor(packet.GetPacketId(), version.GetEncodedValue());
-            commandHandlerFunctor->operator()(packet);
-        }
-        catch(const armnn::TimeoutException&)
-        {
-            if(m_StopAfterTimeout)
-            {
-                m_IsRunning.store(false, std::memory_order_relaxed);
-                return;
-            }
-        }
-        catch(...)
-        {
-            //might want to differentiate the errors more
-            m_IsRunning.store(false, std::memory_order_relaxed);
-            return;
-        }
-
-    } while(m_KeepRunning.load(std::memory_order_relaxed));
-
-    m_IsRunning.store(false, std::memory_order_relaxed);
-}
-
 void CommandThread::Start()
 {
-    if (!m_CommandThread.joinable() && !IsRunning())
+    if (IsRunning()) // already running: nothing to do
     {
-        m_IsRunning.store(true, std::memory_order_relaxed);
-        m_KeepRunning.store(true, std::memory_order_relaxed);
-        m_CommandThread = std::thread(&CommandThread::WaitForPacket, this);
+        return;
     }
+    if (m_CommandThread.joinable()) { m_CommandThread.join(); } // reap a worker that stopped on its own
+    m_IsRunning.store(true, std::memory_order_relaxed);
+    m_KeepRunning.store(true, std::memory_order_relaxed);
+    m_CommandThread = std::thread(&CommandThread::WaitForPacket, this); // would std::terminate if still joinable
 }
 
 void CommandThread::Stop()
 {
     m_KeepRunning.store(false, std::memory_order_relaxed);
-}
 
-void CommandThread::Join()
-{
-    m_CommandThread.join();
+    if (m_CommandThread.joinable()) // false if never started or already joined; repeat calls are safe
+    {
+        m_CommandThread.join(); // may wait for an in-flight ReadPacket(m_Timeout) to return
+    }
 }
 
 bool CommandThread::IsRunning() const
@@ -82,16 +39,49 @@
     return m_IsRunning.load(std::memory_order_relaxed);
 }
 
-bool CommandThread::StopAfterTimeout(bool stopAfterTimeout)
+void CommandThread::SetTimeout(uint32_t timeout) // may be called while running; applies from the next ReadPacket
 {
-    if (!IsRunning())
-    {
-        m_StopAfterTimeout = stopAfterTimeout;
-        return true;
-    }
-    return false;
+    m_Timeout.store(timeout, std::memory_order_relaxed);
 }
 
-}//namespace profiling
+void CommandThread::SetStopAfterTimeout(bool stopAfterTimeout) // true: worker exits on its next read timeout
+{
+    m_StopAfterTimeout.store(stopAfterTimeout, std::memory_order_relaxed);
+}
 
-}//namespace armnn
+void CommandThread::WaitForPacket() // worker-thread body launched by Start()
+{
+    do
+    {
+        try
+        {
+            Packet packet = m_SocketProfilingConnection.ReadPacket(m_Timeout.load(std::memory_order_relaxed));
+            Version version = m_PacketVersionResolver.ResolvePacketVersion(packet.GetPacketId());
+            // Dispatch to the handler registered for this packet id and resolved version
+            CommandHandlerFunctor* commandHandlerFunctor =
+                m_CommandHandlerRegistry.GetFunctor(packet.GetPacketId(), version.GetEncodedValue());
+            BOOST_ASSERT(commandHandlerFunctor);
+            commandHandlerFunctor->operator()(packet);
+        }
+        catch (const armnn::TimeoutException&)
+        {
+            if (m_StopAfterTimeout.load(std::memory_order_relaxed)) // otherwise keep polling
+            {
+                m_KeepRunning.store(false, std::memory_order_relaxed);
+            }
+        }
+        catch (...)
+        {
+            // Any other error stops the thread. Might want to differentiate the errors more
+            m_KeepRunning.store(false, std::memory_order_relaxed);
+        }
+
+    }
+    while (m_KeepRunning.load(std::memory_order_relaxed));
+
+    m_IsRunning.store(false, std::memory_order_relaxed); // lets IsRunning() report that the worker has finished
+}
+
+} // namespace profiling
+
+} // namespace armnn
diff --git a/src/profiling/CommandThread.hpp b/src/profiling/CommandThread.hpp
index 6237cd2..0456ba4 100644
--- a/src/profiling/CommandThread.hpp
+++ b/src/profiling/CommandThread.hpp
@@ -26,19 +26,31 @@
                   bool stopAfterTimeout,
                   CommandHandlerRegistry& commandHandlerRegistry,
                   PacketVersionResolver& packetVersionResolver,
-                  IProfilingConnection& socketProfilingConnection);
+                  IProfilingConnection& socketProfilingConnection)
+        : m_Timeout(timeout)
+        , m_StopAfterTimeout(stopAfterTimeout)
+        , m_IsRunning(false)
+        , m_KeepRunning(false)
+        , m_CommandThread()
+        , m_CommandHandlerRegistry(commandHandlerRegistry)
+        , m_PacketVersionResolver(packetVersionResolver)
+        , m_SocketProfilingConnection(socketProfilingConnection)
+    {}
+    ~CommandThread() { Stop(); }
 
     void Start();
     void Stop();
-    void Join();
+
     bool IsRunning() const;
-    bool StopAfterTimeout(bool StopAfterTimeout);
+
+    void SetTimeout(uint32_t timeout);
+    void SetStopAfterTimeout(bool stopAfterTimeout);
 
 private:
     void WaitForPacket();
 
-    uint32_t m_Timeout;
-    bool m_StopAfterTimeout;
+    std::atomic<uint32_t> m_Timeout;
+    std::atomic<bool> m_StopAfterTimeout;
     std::atomic<bool> m_IsRunning;
     std::atomic<bool> m_KeepRunning;
     std::thread m_CommandThread;
@@ -48,6 +60,6 @@
     IProfilingConnection& m_SocketProfilingConnection;
 };
 
-}//namespace profiling
+} // namespace profiling
 
-}//namespace armnn
\ No newline at end of file
+} // namespace armnn
diff --git a/src/profiling/test/ProfilingTests.cpp b/src/profiling/test/ProfilingTests.cpp
index d14791c..9dd7cd3 100644
--- a/src/profiling/test/ProfilingTests.cpp
+++ b/src/profiling/test/ProfilingTests.cpp
@@ -174,7 +174,6 @@
         commandThread0.Start();
 
         commandThread0.Stop();
-        commandThread0.Join();
 
         BOOST_CHECK(profilingStateMachine.GetCurrentState() == ProfilingState::Active);
 
@@ -188,11 +187,15 @@
                                      testProfilingConnectionTimeOutError);
 
         commandThread1.Start();
-        commandThread1.Join();
+
+        std::this_thread::sleep_for(std::chrono::milliseconds(100));
+
+        BOOST_CHECK(!commandThread1.IsRunning());
+        commandThread1.Stop();
 
         BOOST_CHECK(profilingStateMachine.GetCurrentState() == ProfilingState::WaitingForAck);
         //now commandThread1 should persist after a timeout
-        commandThread1.StopAfterTimeout(false);
+        commandThread1.SetStopAfterTimeout(false);
         commandThread1.Start();
 
         for (int i = 0; i < 100; i++)
@@ -208,11 +211,9 @@
         }
 
         commandThread1.Stop();
-        commandThread1.Join();
 
         BOOST_CHECK(profilingStateMachine.GetCurrentState() == ProfilingState::Active);
 
-
         CommandThread commandThread2(1,
                                      false,
                                      commandHandlerRegistry,
@@ -226,13 +227,13 @@
             if (!commandThread2.IsRunning())
             {
                 //commandThread2 should stop once it encounters a non timing error
-                commandThread2.Join();
                 return;
             }
             std::this_thread::sleep_for(std::chrono::milliseconds(5));
         }
 
         BOOST_ERROR("commandThread2 has failed to stop");
+        commandThread2.Stop();
 }
 
 BOOST_AUTO_TEST_CASE(CheckEncodeVersion)