IVGCVSW-4393 Register backend counters

Signed-off-by: David Monahan <david.monahan@arm.com>
Change-Id: I419ecc2fce4b7e0fcaeb6d1f9cb687c0b660125d
Signed-off-by: Jim Flynn <jim.flynn@arm.com>
diff --git a/src/profiling/CounterDirectory.hpp b/src/profiling/CounterDirectory.hpp
index b0ddbce..22bae89 100644
--- a/src/profiling/CounterDirectory.hpp
+++ b/src/profiling/CounterDirectory.hpp
@@ -6,9 +6,7 @@
 #pragma once
 
 #include "ICounterDirectory.hpp"
-
-#include <armnn/Optional.hpp>
-#include <armnn/BackendId.hpp>
+#include "ICounterRegistry.hpp"
 
 #include <string>
 #include <unordered_set>
@@ -22,7 +20,7 @@
 namespace profiling
 {
 
-class CounterDirectory final : public ICounterDirectory
+class CounterDirectory final : public ICounterDirectory, public ICounterRegistry
 {
 public:
     CounterDirectory() = default;
@@ -31,13 +29,13 @@
     // Register profiling objects
     const Category*   RegisterCategory  (const std::string& categoryName,
                                          const Optional<uint16_t>& deviceUid = EmptyOptional(),
-                                         const Optional<uint16_t>& counterSetUid = EmptyOptional());
+                                         const Optional<uint16_t>& counterSetUid = EmptyOptional()) override;
     const Device*     RegisterDevice    (const std::string& deviceName,
                                          uint16_t cores = 0,
-                                         const Optional<std::string>& parentCategoryName = EmptyOptional());
+                                         const Optional<std::string>& parentCategoryName = EmptyOptional()) override;
     const CounterSet* RegisterCounterSet(const std::string& counterSetName,
                                          uint16_t count = 0,
-                                         const Optional<std::string>& parentCategoryName = EmptyOptional());
+                                         const Optional<std::string>& parentCategoryName = EmptyOptional()) override;
     const Counter* RegisterCounter(const BackendId& backendId,
                                    const uint16_t uid,
                                    const std::string& parentCategoryName,
@@ -49,7 +47,7 @@
                                    const Optional<std::string>& units = EmptyOptional(),
                                    const Optional<uint16_t>& numberOfCores = EmptyOptional(),
                                    const Optional<uint16_t>& deviceUid = EmptyOptional(),
-                                   const Optional<uint16_t>& counterSetUid = EmptyOptional());
+                                   const Optional<uint16_t>& counterSetUid = EmptyOptional()) override;
 
     // Getters for counts
     uint16_t GetCategoryCount()   const override { return boost::numeric_cast<uint16_t>(m_Categories.size());  }
diff --git a/src/profiling/CounterIdMap.cpp b/src/profiling/CounterIdMap.cpp
index 8ee80f9..8626005 100644
--- a/src/profiling/CounterIdMap.cpp
+++ b/src/profiling/CounterIdMap.cpp
@@ -21,6 +21,12 @@
     m_BackendCounterIdMap[backendIdPair] = globalCounterId;
 }
 
+void CounterIdMap::Reset() // empties both id lookup tables held by this object
+{
+    m_BackendCounterIdMap.clear();
+    m_GlobalCounterIdMap.clear();
+}
+
 uint16_t CounterIdMap::GetGlobalId(uint16_t backendCounterId, const armnn::BackendId& backendId) const
 {
     std::pair<uint16_t, armnn::BackendId> backendIdPair(backendCounterId, backendId);
diff --git a/src/profiling/CounterIdMap.hpp b/src/profiling/CounterIdMap.hpp
index cb6b9c9..5c1a6ea 100644
--- a/src/profiling/CounterIdMap.hpp
+++ b/src/profiling/CounterIdMap.hpp
@@ -26,6 +26,7 @@
     virtual void RegisterMapping(uint16_t globalCounterId,
                                  uint16_t backendCounterId,
                                  const armnn::BackendId& backendId) = 0;
+    virtual void Reset() = 0;
     virtual ~IRegisterCounterMapping() {}
 };
 
@@ -38,6 +39,7 @@
     void RegisterMapping(uint16_t globalCounterId,
                          uint16_t backendCounterId,
                          const armnn::BackendId& backendId) override;
+    void Reset() override;
     uint16_t GetGlobalId(uint16_t backendCounterId, const armnn::BackendId& backendId) const override;
     const std::pair<uint16_t, armnn::BackendId>& GetBackendId(uint16_t globalCounterId) const override;
 private:
diff --git a/src/profiling/ICounterRegistry.hpp b/src/profiling/ICounterRegistry.hpp
new file mode 100644
index 0000000..75bc8ef
--- /dev/null
+++ b/src/profiling/ICounterRegistry.hpp
@@ -0,0 +1,52 @@
+//
+// Copyright © 2020 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <armnn/Optional.hpp>
+#include <armnn/BackendId.hpp>
+
+namespace armnn
+{
+
+namespace profiling
+{
+
+class ICounterRegistry
+{
+public:
+    virtual ~ICounterRegistry() = default;
+
+    // Registration entry points (all pure virtual); CounterDirectory provides the implementation.
+    virtual const Category* RegisterCategory(const std::string& categoryName,
+                                             const Optional<uint16_t>& deviceUid,
+                                             const Optional<uint16_t>& counterSetUid) = 0;
+    // Returns a const pointer to a Device; RegisterBackendCounters reads its m_Uid.
+    virtual const Device* RegisterDevice(const std::string& deviceName,
+                                         uint16_t cores,
+                                         const Optional<std::string>& parentCategoryName) = 0;
+    // Returns a const pointer to a CounterSet; RegisterBackendCounters reads its m_Uid.
+    virtual const CounterSet* RegisterCounterSet(const std::string& counterSetName,
+                                                 uint16_t count,
+                                                 const Optional<std::string>& parentCategoryName) = 0;
+    // Returns a const pointer to a Counter; RegisterBackendCounters reads m_Uid and m_MaxCounterUid.
+    virtual const Counter* RegisterCounter(const BackendId& backendId,
+                                           const uint16_t uid,
+                                           const std::string& parentCategoryName,
+                                           uint16_t counterClass,
+                                           uint16_t interpolation,
+                                           double multiplier,
+                                           const std::string& name,
+                                           const std::string& description,
+                                           const Optional<std::string>& units,
+                                           const Optional<uint16_t>& numberOfCores,
+                                           const Optional<uint16_t>& deviceUid,
+                                           const Optional<uint16_t>& counterSetUid) = 0;
+
+};
+
+} // namespace profiling
+
+} // namespace armnn
diff --git a/src/profiling/ProfilingService.cpp b/src/profiling/ProfilingService.cpp
index b06f149..926e54a 100644
--- a/src/profiling/ProfilingService.cpp
+++ b/src/profiling/ProfilingService.cpp
@@ -186,6 +186,11 @@
     return m_CounterDirectory;
 }
 
+ICounterRegistry& ProfilingService::GetCounterRegistry()
+{
+    return m_CounterDirectory; // non-const registration view of the counter directory member
+}
+
 ProfilingState ProfilingService::GetCurrentState() const
 {
     return m_StateMachine.GetCurrentState();
@@ -214,7 +219,7 @@
     return m_CounterIdMap;
 }
 
-IRegisterCounterMapping& ProfilingService::GetCounterMappingRegistrar()
+IRegisterCounterMapping& ProfilingService::GetCounterMappingRegistry()
 {
     return m_CounterIdMap;
 }
@@ -381,6 +386,7 @@
     m_CounterIndex.clear();
     m_CounterValues.clear();
     m_CounterDirectory.Clear();
+    m_CounterIdMap.Reset();
     m_BufferManager.Reset();
 
     // ...finally reset the profiling state machine
diff --git a/src/profiling/ProfilingService.hpp b/src/profiling/ProfilingService.hpp
index 9cf7545..e510589 100644
--- a/src/profiling/ProfilingService.hpp
+++ b/src/profiling/ProfilingService.hpp
@@ -10,6 +10,7 @@
 #include "ConnectionAcknowledgedCommandHandler.hpp"
 #include "CounterDirectory.hpp"
 #include "CounterIdMap.hpp"
+#include "ICounterRegistry.hpp"
 #include "ICounterValues.hpp"
 #include "PeriodicCounterCapture.hpp"
 #include "PeriodicCounterSelectionCommandHandler.hpp"
@@ -63,13 +64,14 @@
     void Disconnect();
 
     const ICounterDirectory& GetCounterDirectory() const;
+    ICounterRegistry& GetCounterRegistry();
     ProfilingState GetCurrentState() const;
     bool IsCounterRegistered(uint16_t counterUid) const override;
     uint32_t GetCounterValue(uint16_t counterUid) const override;
     uint16_t GetCounterCount() const override;
     // counter global/backend mapping functions
     const ICounterMappings& GetCounterMappings() const;
-    IRegisterCounterMapping& GetCounterMappingRegistrar();
+    IRegisterCounterMapping& GetCounterMappingRegistry();
 
     // Getters for the profiling service state
     bool IsProfilingEnabled();
diff --git a/src/profiling/RegisterBackendCounters.cpp b/src/profiling/RegisterBackendCounters.cpp
new file mode 100644
index 0000000..0c68838
--- /dev/null
+++ b/src/profiling/RegisterBackendCounters.cpp
@@ -0,0 +1,87 @@
+//
+// Copyright © 2020 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RegisterBackendCounters.hpp"
+
+namespace armnn
+{
+
+namespace profiling
+{
+
+void RegisterBackendCounters::RegisterCategory(const std::string& categoryName,
+                                               const Optional<uint16_t>& deviceUid,
+                                               const Optional<uint16_t>& counterSetUid)
+{   // Plain forwarder: the Category pointer returned by the registry is deliberately dropped.
+    static_cast<void>(m_CounterDirectory.RegisterCategory(categoryName, deviceUid, counterSetUid));
+}
+
+uint16_t RegisterBackendCounters::RegisterDevice(const std::string& deviceName,
+                                                 uint16_t cores,
+                                                 const Optional<std::string>& parentCategoryName)
+{
+    // Forward to the registry and report only the uid of the Device it hands back.
+    return m_CounterDirectory.RegisterDevice(deviceName, cores, parentCategoryName)->m_Uid;
+}
+
+uint16_t RegisterBackendCounters::RegisterCounterSet(const std::string& counterSetName,
+                                                     uint16_t count,
+                                                     const Optional<std::string>& parentCategoryName)
+{
+    // Forward to the registry and report only the uid of the CounterSet it hands back.
+    return m_CounterDirectory.RegisterCounterSet(counterSetName, count, parentCategoryName)->m_Uid;
+}
+
+// Registers one backend counter with the counter registry and records, for every global uid the
+// registry assigned to it, the mapping back to this backend's own counter id.
+//   uid      - the backend's own id for the counter; it is used only for the id mapping, the
+//              registry is given m_CurrentMaxGlobalCounterID + 1 as the global uid instead.
+//   returns  - the value of m_MaxCounterUid reported by the registry for the new counter.
+// All remaining arguments are passed to the registry's RegisterCounter unchanged.
+uint16_t RegisterBackendCounters::RegisterCounter(const uint16_t uid,
+                                                  const std::string& parentCategoryName,
+                                                  uint16_t counterClass,
+                                                  uint16_t interpolation,
+                                                  double multiplier,
+                                                  const std::string& name,
+                                                  const std::string& description,
+                                                  const Optional<std::string>& units,
+                                                  const Optional<uint16_t>& numberOfCores,
+                                                  const Optional<uint16_t>& deviceUid,
+                                                  const Optional<uint16_t>& counterSetUid)
+{
+    // The new counter is registered one past the highest global uid recorded so far.
+    ++m_CurrentMaxGlobalCounterID;
+    const Counter* counterPtr = m_CounterDirectory.RegisterCounter(m_BackendId,
+                                                                   m_CurrentMaxGlobalCounterID,
+                                                                   parentCategoryName,
+                                                                   counterClass,
+                                                                   interpolation,
+                                                                   multiplier,
+                                                                   name,
+                                                                   description,
+                                                                   units,
+                                                                   numberOfCores,
+                                                                   deviceUid,
+                                                                   counterSetUid);
+    // The counter may cover the uid range [m_Uid, m_MaxCounterUid]; the next one starts after it.
+    m_CurrentMaxGlobalCounterID = counterPtr->m_MaxCounterUid;
+
+    IRegisterCounterMapping& counterIdMap = ProfilingService::Instance().GetCounterMappingRegistry();
+    // Walk the assigned range with a 32-bit cursor: a uint16_t cursor would wrap to 0 after 65535
+    // and the loop would never terminate when m_MaxCounterUid is the largest uint16_t value.
+    const uint32_t lastGlobalCounterId = counterPtr->m_MaxCounterUid;
+    uint16_t backendCounterId = uid;
+    for (uint32_t globalCounterId = counterPtr->m_Uid; globalCounterId <= lastGlobalCounterId; ++globalCounterId)
+    {
+        counterIdMap.RegisterMapping(static_cast<uint16_t>(globalCounterId), backendCounterId, m_BackendId);
+        ++backendCounterId;
+    }
+    return m_CurrentMaxGlobalCounterID;
+}
+
+} // namespace profiling
+
+} // namespace armnn
diff --git a/src/profiling/RegisterBackendCounters.hpp b/src/profiling/RegisterBackendCounters.hpp
new file mode 100644
index 0000000..41886c0
--- /dev/null
+++ b/src/profiling/RegisterBackendCounters.hpp
@@ -0,0 +1,62 @@
+//
+// Copyright © 2020 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "armnn/backends/profiling/IBackendProfiling.hpp"
+#include "CounterIdMap.hpp"
+#include "CounterDirectory.hpp"
+#include "ProfilingService.hpp"
+
+namespace armnn
+{
+
+namespace profiling
+{
+
+class RegisterBackendCounters : public IRegisterBackendCounters
+{
+public:
+    // The first counter is registered at currentMaxGlobalCounterID + 1; backendId tags the id mappings.
+    RegisterBackendCounters(uint16_t currentMaxGlobalCounterID, const BackendId& backendId)
+                            : m_CurrentMaxGlobalCounterID(currentMaxGlobalCounterID),
+                              m_CounterDirectory(ProfilingService::Instance().GetCounterRegistry()),
+                              m_BackendId(backendId) {}
+
+    ~RegisterBackendCounters() = default;
+    // Registers a category with the counter registry; nothing is returned to the caller.
+    void RegisterCategory(const std::string& categoryName,
+                          const Optional<uint16_t>& deviceUid     = EmptyOptional(),
+                          const Optional<uint16_t>& counterSetUid = EmptyOptional()) override;
+    // Registers a device with the counter registry and returns the uid of that device.
+    uint16_t RegisterDevice(const std::string& deviceName,
+                            uint16_t cores = 0,
+                            const Optional<std::string>& parentCategoryName = EmptyOptional()) override;
+    // Registers a counter set with the counter registry and returns the uid of that counter set.
+    uint16_t RegisterCounterSet(const std::string& counterSetName,
+                                uint16_t count = 0,
+                                const Optional<std::string>& parentCategoryName = EmptyOptional()) override;
+    // Registers a counter, maps its global uid(s) to the backend-local 'uid', returns the max global uid.
+    uint16_t RegisterCounter(const uint16_t uid,
+                             const std::string& parentCategoryName,
+                             uint16_t counterClass,
+                             uint16_t interpolation,
+                             double multiplier,
+                             const std::string& name,
+                             const std::string& description,
+                             const Optional<std::string>& units      = EmptyOptional(),
+                             const Optional<uint16_t>& numberOfCores = EmptyOptional(),
+                             const Optional<uint16_t>& deviceUid     = EmptyOptional(),
+                             const Optional<uint16_t>& counterSetUid = EmptyOptional()) override;
+
+private:
+    uint16_t m_CurrentMaxGlobalCounterID; // highest global counter uid recorded so far
+    ICounterRegistry& m_CounterDirectory; // registry obtained from the ProfilingService singleton
+    const BackendId m_BackendId;          // held by value so a temporary constructor argument cannot dangle
+};
+
+} // namespace profiling
+
+} // namespace armnn
\ No newline at end of file
diff --git a/src/profiling/test/ProfilingTests.cpp b/src/profiling/test/ProfilingTests.cpp
index e127a18..d06201d 100644
--- a/src/profiling/test/ProfilingTests.cpp
+++ b/src/profiling/test/ProfilingTests.cpp
@@ -40,6 +40,7 @@
 #include <limits>
 #include <map>
 #include <random>
+#include <RegisterBackendCounters.hpp>
 
 using namespace armnn::profiling;
 using PacketType = MockProfilingConnection::PacketType;
@@ -3168,6 +3169,7 @@
         BOOST_FAIL("Expected string not found.");
     }
 }
+
 BOOST_AUTO_TEST_CASE(CheckCounterIdMap)
 {
     CounterIdMap counterIdMap;
@@ -3208,4 +3210,53 @@
     BOOST_CHECK(counterIdMap.GetGlobalId(1, cpuAccId) == 5);
 }
 
+BOOST_AUTO_TEST_CASE(CheckRegisterBackendCounters)
+{
+    uint16_t globalCounterIds = armnn::profiling::INFERENCES_RUN;
+    armnn::BackendId cpuRefId(armnn::Compute::CpuRef);
+
+    RegisterBackendCounters registerBackendCounters(globalCounterIds, cpuRefId);
+
+    // Reset the profiling service to the uninitialized state
+    armnn::Runtime::CreationOptions::ExternalProfilingOptions options;
+    options.m_EnableProfiling          = true;
+    ProfilingService& profilingService = ProfilingService::Instance();
+    profilingService.ResetExternalProfilingOptions(options, true);
+    // A category registered through the backend registrar must show up in the counter directory
+    BOOST_CHECK(profilingService.GetCounterDirectory().GetCategories().empty());
+    registerBackendCounters.RegisterCategory("categoryOne");
+    auto categoryOnePtr = profilingService.GetCounterDirectory().GetCategory("categoryOne");
+    BOOST_CHECK(categoryOnePtr);
+    // The uid returned by RegisterDevice must resolve to the device; REQUIRE guards the dereference
+    BOOST_CHECK(profilingService.GetCounterDirectory().GetDevices().empty());
+    globalCounterIds = registerBackendCounters.RegisterDevice("deviceOne");
+    auto deviceOnePtr = profilingService.GetCounterDirectory().GetDevice(globalCounterIds);
+    BOOST_REQUIRE(deviceOnePtr);
+    BOOST_CHECK(deviceOnePtr->m_Name == "deviceOne");
+    // The uid returned by RegisterCounterSet must resolve to the set; REQUIRE guards the dereference
+    BOOST_CHECK(profilingService.GetCounterDirectory().GetCounterSets().empty());
+    globalCounterIds = registerBackendCounters.RegisterCounterSet("counterSetOne");
+    auto counterSetOnePtr = profilingService.GetCounterDirectory().GetCounterSet(globalCounterIds);
+    BOOST_REQUIRE(counterSetOnePtr);
+    BOOST_CHECK(counterSetOnePtr->m_Name == "counterSetOne");
+    // Backend counter 0 must receive the global uid after INFERENCES_RUN and be mapped in both directions
+    uint16_t newGlobalCounterId = registerBackendCounters.RegisterCounter(0,
+                                                                          "categoryOne",
+                                                                          0,
+                                                                          0,
+                                                                          1.f,
+                                                                          "CounterOne",
+                                                                          "first test counter");
+    BOOST_CHECK(newGlobalCounterId == armnn::profiling::INFERENCES_RUN + 1);
+    uint16_t mappedGlobalId = profilingService.GetCounterMappings().GetGlobalId(0, cpuRefId);
+    BOOST_CHECK(mappedGlobalId == newGlobalCounterId);
+    auto backendMapping = profilingService.GetCounterMappings().GetBackendId(newGlobalCounterId);
+    BOOST_CHECK(backendMapping.first == 0);
+    BOOST_CHECK(backendMapping.second == cpuRefId);
+
+    // Reset the profiling service to stop any running thread
+    options.m_EnableProfiling = false;
+    profilingService.ResetExternalProfilingOptions(options, true);
+}
+
 BOOST_AUTO_TEST_SUITE_END()