IVGCVSW-3963 Implement the Request Counter Directory Handler

 * Integrated the RequestCounterDirectoryCommandHandler in the
   ProfilingService class
 * Code refactoring
 * Added/Updated unit tests

Signed-off-by: Matteo Martincigh <matteo.martincigh@arm.com>
Change-Id: I60d9f8acf166e29b3dabc921dbdb8149461bd85f
diff --git a/src/profiling/ProfilingService.hpp b/src/profiling/ProfilingService.hpp
index edeb6bd..0e66924 100644
--- a/src/profiling/ProfilingService.hpp
+++ b/src/profiling/ProfilingService.hpp
@@ -13,6 +13,7 @@
 #include "BufferManager.hpp"
 #include "SendCounterPacket.hpp"
 #include "ConnectionAcknowledgedCommandHandler.hpp"
+#include "RequestCounterDirectoryCommandHandler.hpp"
 
 namespace armnn
 {
@@ -81,6 +82,7 @@
     BufferManager m_BufferManager;
     SendCounterPacket m_SendCounterPacket;
     ConnectionAcknowledgedCommandHandler m_ConnectionAcknowledgedCommandHandler;
+    RequestCounterDirectoryCommandHandler m_RequestCounterDirectoryCommandHandler;
 
 protected:
     // Default constructor/destructor kept protected for testing
@@ -103,9 +105,17 @@
         , m_ConnectionAcknowledgedCommandHandler(1,
                                                  m_PacketVersionResolver.ResolvePacketVersion(1).GetEncodedValue(),
                                                  m_StateMachine)
+        , m_RequestCounterDirectoryCommandHandler(3,
+                                                  m_PacketVersionResolver.ResolvePacketVersion(3).GetEncodedValue(),
+                                                  m_CounterDirectory,
+                                                  m_SendCounterPacket,
+                                                  m_StateMachine)
     {
         // Register the "Connection Acknowledged" command handler
         m_CommandHandlerRegistry.RegisterFunctor(&m_ConnectionAcknowledgedCommandHandler);
+
+        // Register the "Request Counter Directory" command handler
+        m_CommandHandlerRegistry.RegisterFunctor(&m_RequestCounterDirectoryCommandHandler);
     }
     ~ProfilingService() = default;
 
@@ -124,6 +134,10 @@
     {
         return instance.m_ProfilingConnection.get();
     }
+    void TransitionToState(ProfilingService& instance, ProfilingState newState)
+    {
+        instance.m_StateMachine.TransitionToState(newState);
+    }
 };
 
 } // namespace profiling
diff --git a/src/profiling/RequestCounterDirectoryCommandHandler.cpp b/src/profiling/RequestCounterDirectoryCommandHandler.cpp
index 0fdcf10..e85acb4 100644
--- a/src/profiling/RequestCounterDirectoryCommandHandler.cpp
+++ b/src/profiling/RequestCounterDirectoryCommandHandler.cpp
@@ -5,7 +5,7 @@
 
 #include "RequestCounterDirectoryCommandHandler.hpp"
 
-#include <boost/assert.hpp>
+#include <boost/format.hpp>
 
 namespace armnn
 {
@@ -15,10 +15,36 @@
 
 void RequestCounterDirectoryCommandHandler::operator()(const Packet& packet)
 {
-    BOOST_ASSERT(packet.GetLength() == 0);
+    ProfilingState currentState = m_StateMachine.GetCurrentState();
+    switch (currentState)
+    {
+    case ProfilingState::Uninitialised:
+    case ProfilingState::NotConnected:
+    case ProfilingState::WaitingForAck:
+        throw RuntimeException(boost::str(boost::format("Request Counter Directory Handler invoked while in an "
+                                                        "wrong state: %1%")
+                                          % GetProfilingStateName(currentState)));
+    case ProfilingState::Active:
+        // Process the packet
+        if (!(packet.GetPacketFamily() == 0u && packet.GetPacketId() == 3u))
+        {
+            throw armnn::InvalidArgumentException(boost::str(boost::format("Expected Packet family = 0, id = 3 but "
+                                                                           "received family = %1%, id = %2%")
+                                                  % packet.GetPacketFamily()
+                                                  % packet.GetPacketId()));
+        }
 
-    // Write packet to Counter Stream Buffer
-    m_SendCounterPacket.SendCounterDirectoryPacket(m_CounterDirectory);
+        // Write a Counter Directory packet to the Counter Stream Buffer
+        m_SendCounterPacket.SendCounterDirectoryPacket(m_CounterDirectory);
+
+        // Notify the Send Thread that new data is available in the Counter Stream Buffer
+        m_SendCounterPacket.SetReadyToRead();
+
+        break;
+    default:
+        throw RuntimeException(boost::str(boost::format("Unknown profiling service state: %1%")
+                                          % static_cast<int>(currentState)));
+    }
 }
 
 } // namespace profiling
diff --git a/src/profiling/RequestCounterDirectoryCommandHandler.hpp b/src/profiling/RequestCounterDirectoryCommandHandler.hpp
index a03300a..02bf64d 100644
--- a/src/profiling/RequestCounterDirectoryCommandHandler.hpp
+++ b/src/profiling/RequestCounterDirectoryCommandHandler.hpp
@@ -8,6 +8,7 @@
 #include "CommandHandlerFunctor.hpp"
 #include "ISendCounterPacket.hpp"
 #include "Packet.hpp"
+#include "ProfilingStateMachine.hpp"
 
 namespace armnn
 {
@@ -19,23 +20,25 @@
 {
 
 public:
-    RequestCounterDirectoryCommandHandler(uint32_t packetId, uint32_t version,
+    RequestCounterDirectoryCommandHandler(uint32_t packetId,
+                                          uint32_t version,
                                           ICounterDirectory& counterDirectory,
-                                          ISendCounterPacket& sendCounterPacket)
-    : CommandHandlerFunctor(packetId, version),
-    m_CounterDirectory(counterDirectory),
-    m_SendCounterPacket(sendCounterPacket)
+                                          ISendCounterPacket& sendCounterPacket,
+                                          ProfilingStateMachine& profilingStateMachine)
+        : CommandHandlerFunctor(packetId, version)
+        , m_CounterDirectory(counterDirectory)
+        , m_SendCounterPacket(sendCounterPacket)
+        , m_StateMachine(profilingStateMachine)
     {}
 
     void operator()(const Packet& packet) override;
 
-
 private:
-    ICounterDirectory& m_CounterDirectory;
+    const ICounterDirectory& m_CounterDirectory;
     ISendCounterPacket& m_SendCounterPacket;
+    const ProfilingStateMachine& m_StateMachine;
 };
 
 } // namespace profiling
 
 } // namespace armnn
-
diff --git a/src/profiling/test/ProfilingTests.cpp b/src/profiling/test/ProfilingTests.cpp
index 80d99dd..57a11d0 100644
--- a/src/profiling/test/ProfilingTests.cpp
+++ b/src/profiling/test/ProfilingTests.cpp
@@ -2119,75 +2119,97 @@
     BOOST_TEST((valueB * numSteps) == readValue);
 }
 
-BOOST_AUTO_TEST_CASE(RequestCounterDirectoryCommandHandlerTest0)
+BOOST_AUTO_TEST_CASE(RequestCounterDirectoryCommandHandlerTest1)
 {
     using boost::numeric_cast;
 
-    ProfilingStateMachine profilingStateMachine;
-
-    const uint32_t packetId = 0x30000;
+    const uint32_t packetId = 3;
     const uint32_t version = 1;
-
-    std::unique_ptr<char[]> packetData;
-
-    Packet packetA(packetId, 0, packetData);
-
+    ProfilingStateMachine profilingStateMachine;
+    CounterDirectory counterDirectory;
     MockBufferManager mockBuffer(1024);
     SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer);
+    RequestCounterDirectoryCommandHandler commandHandler(packetId,
+                                                         version,
+                                                         counterDirectory,
+                                                         sendCounterPacket,
+                                                         profilingStateMachine);
 
-    CounterDirectory counterDirectory;
+    const uint32_t wrongPacketId = 47;
+    const uint32_t wrongHeader = (wrongPacketId & 0x000003FF) << 16;
 
-    RequestCounterDirectoryCommandHandler commandHandler(packetId, version, counterDirectory, sendCounterPacket);
-    commandHandler(packetA);
+    Packet wrongPacket(wrongHeader);
+
+    profilingStateMachine.TransitionToState(ProfilingState::Uninitialised);
+    BOOST_CHECK_THROW(commandHandler(wrongPacket), armnn::RuntimeException); // Wrong profiling state
+    profilingStateMachine.TransitionToState(ProfilingState::NotConnected);
+    BOOST_CHECK_THROW(commandHandler(wrongPacket), armnn::RuntimeException); // Wrong profiling state
+    profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck);
+    BOOST_CHECK_THROW(commandHandler(wrongPacket), armnn::RuntimeException); // Wrong profiling state
+    profilingStateMachine.TransitionToState(ProfilingState::Active);
+    BOOST_CHECK_THROW(commandHandler(wrongPacket), armnn::InvalidArgumentException); // Wrong packet
+
+    const uint32_t rightHeader = (packetId & 0x000003FF) << 16;
+
+    Packet rightPacket(rightHeader);
+
+    BOOST_CHECK_NO_THROW(commandHandler(rightPacket)); // Right packet
 
     auto readBuffer = mockBuffer.GetReadableBuffer();
 
     uint32_t headerWord0 = ReadUint32(readBuffer, 0);
     uint32_t headerWord1 = ReadUint32(readBuffer, 4);
 
-    BOOST_TEST(((headerWord0 >> 26) & 0x3F) == 0);  // packet family
-    BOOST_TEST(((headerWord0 >> 16) & 0x3FF) == 2); // packet id
-    BOOST_TEST(headerWord1 == 24);                  // data length
+    BOOST_TEST(((headerWord0 >> 26) & 0x0000003F) == 0);  // packet family
+    BOOST_TEST(((headerWord0 >> 16) & 0x000003FF) == 2);  // packet id
+    BOOST_TEST(headerWord1 == 24);                        // data length
 
     uint32_t bodyHeaderWord0 = ReadUint32(readBuffer,  8);
     uint16_t deviceRecordCount = numeric_cast<uint16_t>(bodyHeaderWord0 >> 16);
     BOOST_TEST(deviceRecordCount == 0); // device_records_count
 }
 
-BOOST_AUTO_TEST_CASE(RequestCounterDirectoryCommandHandlerTest1)
+BOOST_AUTO_TEST_CASE(RequestCounterDirectoryCommandHandlerTest2)
 {
     using boost::numeric_cast;
 
-    ProfilingStateMachine profilingStateMachine;
-
-    const uint32_t packetId = 0x30000;
+    const uint32_t packetId = 3;
     const uint32_t version = 1;
-
-    std::unique_ptr<char[]> packetData;
-
-    Packet packetA(packetId, 0, packetData);
-
+    ProfilingStateMachine profilingStateMachine;
+    CounterDirectory counterDirectory;
     MockBufferManager mockBuffer(1024);
     SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer);
+    RequestCounterDirectoryCommandHandler commandHandler(packetId,
+                                                         version,
+                                                         counterDirectory,
+                                                         sendCounterPacket,
+                                                         profilingStateMachine);
+    const uint32_t header = (packetId & 0x000003FF) << 16;
+    Packet packet(header);
 
-    CounterDirectory counterDirectory;
     const Device* device = counterDirectory.RegisterDevice("deviceA", 1);
     const CounterSet* counterSet = counterDirectory.RegisterCounterSet("countersetA");
     counterDirectory.RegisterCategory("categoryA", device->m_Uid, counterSet->m_Uid);
     counterDirectory.RegisterCounter("categoryA", 0, 1, 2.0f, "counterA", "descA");
     counterDirectory.RegisterCounter("categoryA", 1, 1, 3.0f, "counterB", "descB");
 
-    RequestCounterDirectoryCommandHandler commandHandler(packetId, version, counterDirectory, sendCounterPacket);
-    commandHandler(packetA);
+    profilingStateMachine.TransitionToState(ProfilingState::Uninitialised);
+    BOOST_CHECK_THROW(commandHandler(packet), armnn::RuntimeException); // Wrong profiling state
+    profilingStateMachine.TransitionToState(ProfilingState::NotConnected);
+    BOOST_CHECK_THROW(commandHandler(packet), armnn::RuntimeException); // Wrong profiling state
+    profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck);
+    BOOST_CHECK_THROW(commandHandler(packet), armnn::RuntimeException); // Wrong profiling state
+    profilingStateMachine.TransitionToState(ProfilingState::Active);
+    BOOST_CHECK_NO_THROW(commandHandler(packet));
 
     auto readBuffer = mockBuffer.GetReadableBuffer();
 
     uint32_t headerWord0 = ReadUint32(readBuffer, 0);
     uint32_t headerWord1 = ReadUint32(readBuffer, 4);
 
-    BOOST_TEST(((headerWord0 >> 26) & 0x3F) == 0);  // packet family
-    BOOST_TEST(((headerWord0 >> 16) & 0x3FF) == 2); // packet id
-    BOOST_TEST(headerWord1 == 240);                 // data length
+    BOOST_TEST(((headerWord0 >> 26) & 0x0000003F) == 0);  // packet family
+    BOOST_TEST(((headerWord0 >> 16) & 0x000003FF) == 2);  // packet id
+    BOOST_TEST(headerWord1 == 240);                       // data length
 
     uint32_t bodyHeaderWord0 = ReadUint32(readBuffer,  8);
     uint32_t bodyHeaderWord1 = ReadUint32(readBuffer, 12);
@@ -2357,4 +2379,131 @@
     profilingService.ResetExternalProfilingOptions(options, true);
 }
 
+BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadRequestCounterDirectoryPacket)
+{
+    // Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output
+    LogLevelSwapper logLevelSwapper(armnn::LogSeverity::Warning);
+
+    // Swap the profiling connection factory in the profiling service instance with our mock one
+    SwapProfilingConnectionFactoryHelper helper;
+
+    // Redirect the standard output to a local stream so that we can parse the warning message
+    std::stringstream ss;
+    StreamRedirector streamRedirector(std::cout, ss.rdbuf());
+
+    // Reset the profiling service to the uninitialized state
+    armnn::Runtime::CreationOptions::ExternalProfilingOptions options;
+    options.m_EnableProfiling = true;
+    ProfilingService& profilingService = ProfilingService::Instance();
+    profilingService.ResetExternalProfilingOptions(options, true);
+
+    // Bring the profiling service to the "Active" state
+    BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
+    helper.ForceTransitionToState(ProfilingState::NotConnected);
+    BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
+    profilingService.Update(); // Create the profiling connection
+    BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
+    profilingService.Update(); // Start the threads
+    helper.ForceTransitionToState(ProfilingState::Active);
+    BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
+
+    // Get the mock profiling connection
+    MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
+    BOOST_CHECK(mockProfilingConnection);
+
+    // Write a valid "Request Counter Directory" packet into the mock profiling connection, to simulate a valid
+    // reply from an external profiling service
+
+    // Request Counter Directory packet header (word 0, word 1 is always zero):
+    // 26:31 [6]  packet_family: Control Packet Family, value 0b000000
+    // 16:25 [10] packet_id: Packet identifier, value 0b0000000011
+    // 8:15  [8]  reserved: Reserved, value 0b00000000
+    // 0:7   [8]  reserved: Reserved, value 0b00000000
+    uint32_t packetFamily = 0;
+    uint32_t packetId     = 123; // Wrong packet id!!!
+    uint32_t header = ((packetFamily & 0x0000003F) << 26) |
+                      ((packetId     & 0x000003FF) << 16);
+
+    // Create the Request Counter Directory packet
+    Packet requestCounterDirectoryPacket(header);
+
+    // Write the packet to the mock profiling connection
+    mockProfilingConnection->WritePacket(std::move(requestCounterDirectoryPacket));
+
+    // Wait for a bit (must at least be the delay value of the mock profiling connection) to make sure that
+    // the Create the Request Counter packet gets processed by the profiling service
+    std::this_thread::sleep_for(std::chrono::seconds(2));
+
+    // Check that the expected error has occurred and logged to the standard output
+    BOOST_CHECK(boost::contains(ss.str(), "Functor with requested PacketId=123 and Version=4194304 does not exist"));
+
+    // The Connection Acknowledged Command Handler should not have updated the profiling state
+    BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
+
+    // Reset the profiling service to stop any running thread
+    options.m_EnableProfiling = false;
+    profilingService.ResetExternalProfilingOptions(options, true);
+}
+
+BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodRequestCounterDirectoryPacket)
+{
+    // Swap the profiling connection factory in the profiling service instance with our mock one
+    SwapProfilingConnectionFactoryHelper helper;
+
+    // Reset the profiling service to the uninitialized state
+    armnn::Runtime::CreationOptions::ExternalProfilingOptions options;
+    options.m_EnableProfiling = true;
+    ProfilingService& profilingService = ProfilingService::Instance();
+    profilingService.ResetExternalProfilingOptions(options, true);
+
+    // Bring the profiling service to the "Active" state
+    BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
+    profilingService.Update(); // Initialize the counter directory
+    BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
+    profilingService.Update(); // Create the profiling connection
+    BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
+    profilingService.Update(); // Start the threads
+    helper.ForceTransitionToState(ProfilingState::Active);
+    BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
+
+    // Get the mock profiling connection
+    MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
+    BOOST_CHECK(mockProfilingConnection);
+
+    // Write a valid "Request Counter Directory" packet into the mock profiling connection, to simulate a valid
+    // reply from an external profiling service
+
+    // Request Counter Directory packet header (word 0, word 1 is always zero):
+    // 26:31 [6]  packet_family: Control Packet Family, value 0b000000
+    // 16:25 [10] packet_id: Packet identifier, value 0b0000000011
+    // 8:15  [8]  reserved: Reserved, value 0b00000000
+    // 0:7   [8]  reserved: Reserved, value 0b00000000
+    uint32_t packetFamily = 0;
+    uint32_t packetId     = 3;
+    uint32_t header = ((packetFamily & 0x0000003F) << 26) |
+                      ((packetId     & 0x000003FF) << 16);
+
+    // Create the Request Counter Directory packet
+    Packet requestCounterDirectoryPacket(header);
+
+    // Write the packet to the mock profiling connection
+    mockProfilingConnection->WritePacket(std::move(requestCounterDirectoryPacket));
+
+    // Wait for a bit (must at least be the delay value of the mock profiling connection) to make sure that
+    // the Create the Request Counter packet gets processed by the profiling service
+    std::this_thread::sleep_for(std::chrono::seconds(2));
+
+    // The Connection Acknowledged Command Handler should not have updated the profiling state
+    BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
+
+    // Check that the mock profiling connection contains one Counter Directory packet
+    const std::vector<uint32_t> writtenData = mockProfilingConnection->GetWrittenData();
+    BOOST_TEST(writtenData.size() == 1);
+    BOOST_TEST(writtenData[0] == 416); // The size of a valid Counter Directory packet
+
+    // Reset the profiling service to stop any running thread
+    options.m_EnableProfiling = false;
+    profilingService.ResetExternalProfilingOptions(options, true);
+}
+
 BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/profiling/test/ProfilingTests.hpp b/src/profiling/test/ProfilingTests.hpp
index 3e6cf63..e168616 100644
--- a/src/profiling/test/ProfilingTests.hpp
+++ b/src/profiling/test/ProfilingTests.hpp
@@ -40,17 +40,6 @@
     }
 };
 
-struct CoutRedirect
-{
-public:
-    CoutRedirect(std::streambuf* newStreamBuffer)
-        : m_Old(std::cout.rdbuf(newStreamBuffer)) {}
-    ~CoutRedirect() { std::cout.rdbuf(m_Old); }
-
-private:
-    std::streambuf* m_Old;
-};
-
 struct StreamRedirector
 {
 public:
@@ -190,6 +179,11 @@
         return boost::polymorphic_downcast<MockProfilingConnection*>(profilingConnection);
     }
 
+    void ForceTransitionToState(ProfilingState newState)
+    {
+        TransitionToState(ProfilingService::Instance(), newState);
+    }
+
 private:
     MockProfilingConnectionFactoryPtr m_MockProfilingConnectionFactory;
     IProfilingConnectionFactory* m_BackupProfilingConnectionFactory;