IVGCVSW-3937 Refactor and improve the CommandHandleRegistry class

 * Added simplified RegisterFunctor method
 * Code refactoring
 * Updated the unit tests accordingly

Signed-off-by: Matteo Martincigh <matteo.martincigh@arm.com>
Change-Id: Iee941d898facd9c1ab5366e87c611c99a0468830
diff --git a/src/profiling/CommandHandler.cpp b/src/profiling/CommandHandler.cpp
index 5eddfd5..4978405 100644
--- a/src/profiling/CommandHandler.cpp
+++ b/src/profiling/CommandHandler.cpp
@@ -18,14 +18,14 @@
         return;
     }
 
-    m_IsRunning.store(true, std::memory_order_relaxed);
-    m_KeepRunning.store(true, std::memory_order_relaxed);
+    m_IsRunning.store(true);
+    m_KeepRunning.store(true);
     m_CommandThread = std::thread(&CommandHandler::HandleCommands, this, std::ref(profilingConnection));
 }
 
 void CommandHandler::Stop()
 {
-    m_KeepRunning.store(false, std::memory_order_relaxed);
+    m_KeepRunning.store(false);
 
     if (m_CommandThread.joinable())
     {
@@ -33,21 +33,6 @@
     }
 }
 
-bool CommandHandler::IsRunning() const
-{
-    return m_IsRunning.load(std::memory_order_relaxed);
-}
-
-void CommandHandler::SetTimeout(uint32_t timeout)
-{
-    m_Timeout.store(timeout, std::memory_order_relaxed);
-}
-
-void CommandHandler::SetStopAfterTimeout(bool stopAfterTimeout)
-{
-    m_StopAfterTimeout.store(stopAfterTimeout, std::memory_order_relaxed);
-}
-
 void CommandHandler::HandleCommands(IProfilingConnection& profilingConnection)
 {
     do
@@ -72,12 +57,12 @@
         catch (...)
         {
             // Might want to differentiate the errors more
-            m_KeepRunning.store(false, std::memory_order_relaxed);
+            m_KeepRunning.store(false);
         }
     }
-    while (m_KeepRunning.load(std::memory_order_relaxed));
+    while (m_KeepRunning.load());
 
-    m_IsRunning.store(false, std::memory_order_relaxed);
+    m_IsRunning.store(false);
 }
 
 } // namespace profiling
diff --git a/src/profiling/CommandHandler.hpp b/src/profiling/CommandHandler.hpp
index 598eabd..0cc2342 100644
--- a/src/profiling/CommandHandler.hpp
+++ b/src/profiling/CommandHandler.hpp
@@ -35,13 +35,12 @@
     {}
     ~CommandHandler() { Stop(); }
 
+    void SetTimeout(uint32_t timeout) { m_Timeout.store(timeout); }
+    void SetStopAfterTimeout(bool stopAfterTimeout) { m_StopAfterTimeout.store(stopAfterTimeout); }
+
     void Start(IProfilingConnection& profilingConnection);
     void Stop();
-
-    bool IsRunning() const;
-
-    void SetTimeout(uint32_t timeout);
-    void SetStopAfterTimeout(bool stopAfterTimeout);
+    bool IsRunning() const { return m_IsRunning.load(); }
 
 private:
     void HandleCommands(IProfilingConnection& profilingConnection);
diff --git a/src/profiling/CommandHandlerFunctor.hpp b/src/profiling/CommandHandlerFunctor.hpp
index a9a59c1..2e1e05f 100644
--- a/src/profiling/CommandHandlerFunctor.hpp
+++ b/src/profiling/CommandHandlerFunctor.hpp
@@ -18,12 +18,15 @@
 class CommandHandlerFunctor
 {
 public:
-    CommandHandlerFunctor(uint32_t packetId, uint32_t version) : m_PacketId(packetId), m_Version(version) {};
+    CommandHandlerFunctor(uint32_t packetId, uint32_t version)
+        : m_PacketId(packetId)
+        , m_Version(version)
+    {}
 
     uint32_t GetPacketId() const;
     uint32_t GetVersion()  const;
 
-    virtual void operator()(const Packet& packet) {};
+    virtual void operator()(const Packet& packet) {}
 
 private:
     uint32_t m_PacketId;
diff --git a/src/profiling/CommandHandlerRegistry.cpp b/src/profiling/CommandHandlerRegistry.cpp
index 9731347..bd9b318 100644
--- a/src/profiling/CommandHandlerRegistry.cpp
+++ b/src/profiling/CommandHandlerRegistry.cpp
@@ -6,7 +6,7 @@
 #include "CommandHandlerRegistry.hpp"
 
 #include <boost/assert.hpp>
-#include <boost/log/trivial.hpp>
+#include <boost/format.hpp>
 
 namespace armnn
 {
@@ -16,11 +16,19 @@
 
 void CommandHandlerRegistry::RegisterFunctor(CommandHandlerFunctor* functor, uint32_t packetId, uint32_t version)
 {
-    BOOST_ASSERT_MSG(functor, "Provided functor should not be a nullptr.");
+    BOOST_ASSERT_MSG(functor, "Provided functor should not be a nullptr");
+
     CommandHandlerKey key(packetId, version);
     registry[key] = functor;
 }
 
+void CommandHandlerRegistry::RegisterFunctor(CommandHandlerFunctor* functor)
+{
+    BOOST_ASSERT_MSG(functor, "Provided functor should not be a nullptr");
+
+    RegisterFunctor(functor, functor->GetPacketId(), functor->GetVersion());
+}
+
 CommandHandlerFunctor* CommandHandlerRegistry::GetFunctor(uint32_t packetId, uint32_t version) const
 {
     CommandHandlerKey key(packetId, version);
@@ -28,10 +36,22 @@
     // Check that the requested key exists
     if (registry.find(key) == registry.end())
     {
-        throw armnn::Exception("Functor with requested PacketId or Version does not exist.");
+        throw armnn::InvalidArgumentException(
+                    boost::str(boost::format("Functor with requested PacketId=%1% and Version=%2% does not exist")
+                               % packetId
+                               % version));
     }
 
-    return registry.at(key);
+    CommandHandlerFunctor* commandHandlerFunctor = registry.at(key);
+    if (commandHandlerFunctor == nullptr)
+    {
+        throw RuntimeException(
+                    boost::str(boost::format("Invalid functor registered for PacketId=%1% and Version=%2%")
+                               % packetId
+                               % version));
+    }
+
+    return commandHandlerFunctor;
 }
 
 } // namespace profiling
diff --git a/src/profiling/CommandHandlerRegistry.hpp b/src/profiling/CommandHandlerRegistry.hpp
index 61d45b0..9d514bf 100644
--- a/src/profiling/CommandHandlerRegistry.hpp
+++ b/src/profiling/CommandHandlerRegistry.hpp
@@ -36,6 +36,8 @@
 
     void RegisterFunctor(CommandHandlerFunctor* functor, uint32_t packetId, uint32_t version);
 
+    void RegisterFunctor(CommandHandlerFunctor* functor);
+
     CommandHandlerFunctor* GetFunctor(uint32_t packetId, uint32_t version) const;
 
 private:
diff --git a/src/profiling/ConnectionAcknowledgedCommandHandler.cpp b/src/profiling/ConnectionAcknowledgedCommandHandler.cpp
index 0f83a31..f90b601 100644
--- a/src/profiling/ConnectionAcknowledgedCommandHandler.cpp
+++ b/src/profiling/ConnectionAcknowledgedCommandHandler.cpp
@@ -17,10 +17,12 @@
 {
     if (!(packet.GetPacketFamily() == 0u && packet.GetPacketId() == 1u))
     {
-        throw armnn::Exception(std::string("Expected Packet family = 0, id = 1 but received family =")
-                               + std::to_string(packet.GetPacketFamily())
-                               +" id = " + std::to_string(packet.GetPacketId()));
+        throw armnn::InvalidArgumentException(std::string("Expected Packet family = 0, id = 1 but received family = ")
+                                              + std::to_string(packet.GetPacketFamily())
+                                              + " id = " + std::to_string(packet.GetPacketId()));
     }
+
+    // Once a Connection Acknowledged packet has been received, move to the Active state immediately
     m_StateMachine.TransitionToState(ProfilingState::Active);
 }
 
diff --git a/src/profiling/ConnectionAcknowledgedCommandHandler.hpp b/src/profiling/ConnectionAcknowledgedCommandHandler.hpp
index f61495e..d0dc07a 100644
--- a/src/profiling/ConnectionAcknowledgedCommandHandler.hpp
+++ b/src/profiling/ConnectionAcknowledgedCommandHandler.hpp
@@ -15,14 +15,16 @@
 namespace profiling
 {
 
-class ConnectionAcknowledgedCommandHandler : public CommandHandlerFunctor
+class ConnectionAcknowledgedCommandHandler final : public CommandHandlerFunctor
 {
 
 public:
     ConnectionAcknowledgedCommandHandler(uint32_t packetId,
                                          uint32_t version,
                                          ProfilingStateMachine& profilingStateMachine)
-        : CommandHandlerFunctor(packetId, version), m_StateMachine(profilingStateMachine) {}
+        : CommandHandlerFunctor(packetId, version)
+        , m_StateMachine(profilingStateMachine)
+    {}
 
     void operator()(const Packet& packet) override;
 
diff --git a/src/profiling/test/ProfilingTests.cpp b/src/profiling/test/ProfilingTests.cpp
index ba1e6cf..91568d1 100644
--- a/src/profiling/test/ProfilingTests.cpp
+++ b/src/profiling/test/ProfilingTests.cpp
@@ -154,7 +154,7 @@
     ConnectionAcknowledgedCommandHandler connectionAcknowledgedCommandHandler(1, 4194304, profilingStateMachine);
     CommandHandlerRegistry commandHandlerRegistry;
 
-    commandHandlerRegistry.RegisterFunctor(&connectionAcknowledgedCommandHandler, 1, 4194304);
+    commandHandlerRegistry.RegisterFunctor(&connectionAcknowledgedCommandHandler);
 
     profilingStateMachine.TransitionToState(ProfilingState::NotConnected);
     profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck);
@@ -388,9 +388,9 @@
     CommandHandlerRegistry registry;
 
     // Register multiple different derived classes
-    registry.RegisterFunctor(&testFunctorA, testFunctorA.GetPacketId(), testFunctorA.GetVersion());
-    registry.RegisterFunctor(&testFunctorB, testFunctorB.GetPacketId(), testFunctorB.GetVersion());
-    registry.RegisterFunctor(&testFunctorC, testFunctorC.GetPacketId(), testFunctorC.GetVersion());
+    registry.RegisterFunctor(&testFunctorA);
+    registry.RegisterFunctor(&testFunctorB);
+    registry.RegisterFunctor(&testFunctorC);
 
     std::unique_ptr<char[]> packetDataA;
     std::unique_ptr<char[]> packetDataB;