IVGCVSW-5813 Add Async Queue to IRuntime

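Add a priority queue and thread pool to LoadedNetwork so that
asynchronous executions can be scheduled rather than run inline:

 * A thread pool is created at network load time when the network is
   asynchronous and m_NumThreads > 0; each worker thread owns its own
   IWorkingMemHandle.
 * Schedule() groups the input/output tensors and the caller's
   IAsyncExecutionCallback into a tuple and pushes it onto a high,
   medium or low priority queue.
 * Worker threads drain the queues from high to low priority, with an
   expire rate that bounds consecutive higher-priority work so lower
   priorities cannot be starved.
 * The callback is notified with the execution status and the start/end
   timestamps of the inference.

Minimal usage sketch (illustrative only: LoggingCallback is
hypothetical, and the Notify() signature is inferred from the call made
by the worker loop in this patch):

    // Hypothetical callback that logs the status and duration of an inference
    class LoggingCallback : public IAsyncExecutionCallback
    {
    public:
        void Notify(Status status,
                    std::pair<HighResolutionClock, HighResolutionClock> timeTaken) override
        {
            std::chrono::duration<double, std::milli> ms = timeTaken.second - timeTaken.first;
            ARMNN_LOG(info) << "Inference "
                            << (status == Status::Success ? "succeeded" : "failed")
                            << " took " << ms.count() << " ms";
        }
    };

    auto cb = std::make_shared<LoggingCallback>();
    loadedNetwork->Schedule(inputTensors, outputTensors, QosExecPriority::Medium, cb);
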
Signed-off-by: Keith Davis <keith.davis@arm.com>
Change-Id: Icc0d131c8ee2e9748e2f14762a75962b39c10f9d
diff --git a/src/armnn/LoadedNetwork.cpp b/src/armnn/LoadedNetwork.cpp
index 46eb988..67de00f 100644
--- a/src/armnn/LoadedNetwork.cpp
+++ b/src/armnn/LoadedNetwork.cpp
@@ -24,6 +24,7 @@
 #include <LabelsAndEventClasses.hpp>
 
 #include <fmt/format.h>
+#include <armnn/utility/Timer.hpp>
 
 namespace armnn
 {
@@ -84,7 +85,8 @@
 std::unique_ptr<LoadedNetwork> LoadedNetwork::MakeLoadedNetwork(std::unique_ptr<IOptimizedNetwork> net,
                                                                 std::string& errorMessage,
                                                                 const INetworkProperties& networkProperties,
-                                                                profiling::ProfilingService&  profilingService)
+                                                                profiling::ProfilingService&  profilingService,
+                                                                const NetworkId networkId)
 {
     std::unique_ptr<LoadedNetwork> loadedNetwork;
 
@@ -98,7 +100,7 @@
 
     try
     {
-        loadedNetwork.reset(new LoadedNetwork(std::move(net), networkProperties, profilingService));
+        loadedNetwork.reset(new LoadedNetwork(std::move(net), networkProperties, profilingService, networkId));
     }
     catch (const armnn::RuntimeException& error)
     {
@@ -118,9 +120,11 @@
 
 LoadedNetwork::LoadedNetwork(std::unique_ptr<IOptimizedNetwork> net,
                              const INetworkProperties& networkProperties,
-                             profiling::ProfilingService&  profilingService) :
+                             profiling::ProfilingService&  profilingService,
+                             const NetworkId networkId) :
                              m_OptimizedNetwork(std::move(net)),
                              m_NetworkProperties(networkProperties),
+                             m_NetworkId(networkId),
                              m_TensorHandleFactoryRegistry(),
                              m_ProfilingService(profilingService)
 {
@@ -161,6 +165,14 @@
             }
         }
     }
+
+    // Create the thread pool, which assigns a working memory handle to each thread
+    // This must occur after the factories are registered so that the WorkingMemHandles can be created
+    if (m_NetworkProperties.m_NumThreads > 0 && m_NetworkProperties.m_AsyncEnabled)
+    {
+        CreateThreadPool(m_NetworkProperties.m_NumThreads);
+    }
+
     if (!networkProperties.m_AsyncEnabled)
     {
         for (auto &&layer : order)
@@ -846,6 +858,147 @@
     return success;
 }
 
+void LoadedNetwork::CreateThreadPool(std::size_t numThreads)
+{
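+    // Each worker thread is given its own working memory handle, so concurrent
+    // executions never share intermediate tensor storage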
+    for (auto i = 0u; i < numThreads; ++i)
+    {
+        std::unique_ptr<IWorkingMemHandle> workingMemHandle = CreateWorkingMemHandle(m_NetworkId);
+        m_Threads.emplace_back(
+            std::make_unique<std::thread>(
+                &LoadedNetwork::ProcessExecPriorities,
+                this,
+                std::move(workingMemHandle)
+            )
+        );
+    }
+}
+
+void LoadedNetwork::TerminateThreadPool() noexcept
+{
+    {
+        std::unique_lock<std::mutex> threadPoolLock(m_ThreadPoolMutex);
+        m_TerminatePool = true;
+    }
+
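+    // Wake every worker; each will observe m_TerminatePool and exit once the queues are drained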
+    m_ThreadPoolEvent.notify_all();
+
+    for (auto &thread : m_Threads)
+    {
+        thread->join();
+    }
+}
+
+void LoadedNetwork::Schedule(const InputTensors& inputTensors,
+                             const OutputTensors& outputTensors,
+                             const QosExecPriority priority,
+                             std::shared_ptr<IAsyncExecutionCallback> cb)
+{
+    // Group execution parameters so that they can be easily added to the queue
+    ExecutionTuple groupExecParams = std::make_tuple(inputTensors, outputTensors, cb);
+    std::shared_ptr<ExecutionTuple> operation = std::make_shared<ExecutionTuple>(groupExecParams);
+
+    // Add a message to the queue and notify the request thread
+    std::unique_lock<std::mutex> lock(m_ThreadPoolMutex);
+    switch (priority)
+    {
+        case QosExecPriority::High:
+            m_HighPriorityQueue.push(operation);
+            break;
+        case QosExecPriority::Low:
+            m_LowPriorityQueue.push(operation);
+            break;
+        case QosExecPriority::Medium:
+        default:
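+            // Medium and any unrecognised priority land on the medium queue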
+            m_MediumPriorityQueue.push(operation);
+    }
+    m_ThreadPoolEvent.notify_one();
+}
+
+void LoadedNetwork::ProcessExecPriorities(std::unique_ptr<IWorkingMemHandle> workingMemHandle)
+{
+    int expireRate          = EXPIRE_RATE;
+    int highPriorityCount   = 0;
+    int mediumPriorityCount = 0;
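+
+    // The expire rate bounds how many messages in a row a worker may take from a
+    // higher-priority queue before servicing the one below it, so lower priority
+    // work cannot be starved indefinitely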
+
+    IWorkingMemHandle& workingMemHandleRef = *workingMemHandle;
+
+    while (true)
+    {
+        std::shared_ptr<ExecutionTuple> currentExecInProgress(nullptr);
+        {
+            // Wait for a message to be added to the queue
+            // This is in a separate scope to minimise the lifetime of the lock
+            std::unique_lock<std::mutex> lock(m_ThreadPoolMutex);
+
+            m_ThreadPoolEvent.wait(lock,
+                                   [this] {
+                                       return m_TerminatePool || !m_HighPriorityQueue.empty() ||
+                                              !m_MediumPriorityQueue.empty() || !m_LowPriorityQueue.empty();
+                                   });
+
+            if (m_TerminatePool && m_HighPriorityQueue.empty() && m_MediumPriorityQueue.empty() &&
+                m_LowPriorityQueue.empty())
+            {
+                break;
+            }
+
+            // Take a message from the front of each queue, in priority order from high to low
+            // Take high priority first, provided its count has not yet reached the expire rate
+            if (!m_HighPriorityQueue.empty() && highPriorityCount < expireRate)
+            {
+                currentExecInProgress = m_HighPriorityQueue.front();
+                m_HighPriorityQueue.pop();
+                highPriorityCount += 1;
+            }
+            // If the high priority queue is empty or its count has reached the expire rate, take a medium priority message
+            else if (!m_MediumPriorityQueue.empty() && mediumPriorityCount < expireRate)
+            {
+                currentExecInProgress = m_MediumPriorityQueue.front();
+                m_MediumPriorityQueue.pop();
+                mediumPriorityCount += 1;
+                // Reset high priority count
+                highPriorityCount     = 0;
+            }
+            // If the medium priority queue is empty or its count has reached the expire rate, take a low priority message
+            else if (!m_LowPriorityQueue.empty())
+            {
+                currentExecInProgress = m_LowPriorityQueue.front();
+                m_LowPriorityQueue.pop();
+                // Reset high and medium priority count
+                highPriorityCount   = 0;
+                mediumPriorityCount = 0;
+            }
+            else
+            {
+                // Reset high and medium priority count
+                highPriorityCount   = 0;
+                mediumPriorityCount = 0;
+                continue;
+            }
+        }
+
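+        // The inference runs outside the queue lock so other workers can keep
+        // pulling messages while this one executes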
+        // Unpack the message and invoke the asynchronous execution
+        auto inputTensors  = std::get<0>(*currentExecInProgress);
+        auto outputTensors = std::get<1>(*currentExecInProgress);
+        auto cb            = std::get<2>(*currentExecInProgress);
+
+        // Get time at start of inference
+        HighResolutionClock startTime = armnn::GetTimeNow();
+
+        try // executing the inference
+        {
+            // Execute and populate the time at end of inference in the callback
+            Execute(inputTensors, outputTensors, workingMemHandleRef) == Status::Success ?
+                cb->Notify(Status::Success, std::make_pair(startTime, armnn::GetTimeNow())) :
+                cb->Notify(Status::Failure, std::make_pair(startTime, armnn::GetTimeNow()));
+        }
+        catch (const RuntimeException&)
+        {
+            cb->Notify(Status::Failure, std::make_pair(startTime, armnn::GetTimeNow()));
+        }
+    }
+}
+
 void LoadedNetwork::EnqueueInput(const BindableLayer& layer,
                                  const ConstTensor& inputTensor,
                                  WorkingMemHandle& context)
@@ -1096,6 +1249,7 @@
             EnqueueOutput(*outputLayer, GetOutputTensor(outputLayer->GetBindingId(), outputTensors), workingMemHandle);
         }
     }
+
     return executionSucceeded ? Status::Success : Status::Failure;
 }