IVGCVSW-5781 Add Async Support to Android-NN-Driver

Signed-off-by: Finn Williams <Finn.Williams@arm.com>
Change-Id: I1f13d04100fdb119495b9e3054425bf3babc59f1
diff --git a/ArmnnPreparedModel.cpp b/ArmnnPreparedModel.cpp
index 60beac4..978f378 100644
--- a/ArmnnPreparedModel.cpp
+++ b/ArmnnPreparedModel.cpp
@@ -112,16 +112,23 @@
                                                    armnn::IRuntime* runtime,
                                                    const HalModel& model,
                                                    const std::string& requestInputsAndOutputsDumpDir,
-                                                   const bool gpuProfilingEnabled)
+                                                   const bool gpuProfilingEnabled,
+                                                   const bool asyncModelExecutionEnabled)
     : m_NetworkId(networkId)
     , m_Runtime(runtime)
     , m_Model(model)
     , m_RequestCount(0)
     , m_RequestInputsAndOutputsDumpDir(requestInputsAndOutputsDumpDir)
     , m_GpuProfilingEnabled(gpuProfilingEnabled)
+    , m_AsyncModelExecutionEnabled(asyncModelExecutionEnabled)
 {
     // Enable profiling if required.
     m_Runtime->GetProfiler(m_NetworkId)->EnableProfiling(m_GpuProfilingEnabled);
+
+    if (asyncModelExecutionEnabled)
+    {
+        m_WorkingMemHandle = m_Runtime->CreateWorkingMemHandle(networkId);
+    }
 }
 
 template<typename HalVersion>
@@ -225,8 +232,6 @@
         return V1_0::ErrorStatus::GENERAL_FAILURE;
     }
 
-    ALOGV("ArmnnPreparedModel::execute(...) before PostMsg");
-
     auto cb = [callback](V1_0::ErrorStatus errorStatus, std::string callingFunction)
     {
         NotifyCallbackAndCheck(callback, errorStatus, callingFunction);
@@ -234,7 +239,17 @@
 
     CallbackContext_1_0 armnnCb;
     armnnCb.callback = cb;
+
+    if (m_AsyncModelExecutionEnabled)
+    {
+        ALOGV("ArmnnPreparedModel::execute(...) before ScheduleGraphForExecution");
+        ScheduleGraphForExecution(pMemPools, pInputTensors, pOutputTensors, armnnCb);
+        ALOGV("ArmnnPreparedModel::execute(...) after ScheduleGraphForExecution");
+        return V1_0::ErrorStatus::NONE;
+    }
+
     // post the request for asynchronous execution
+    ALOGV("ArmnnPreparedModel::execute(...) before PostMsg");
     m_RequestThread.PostMsg(this, pMemPools, pInputTensors, pOutputTensors, armnnCb);
     ALOGV("ArmnnPreparedModel::execute(...) after PostMsg");
     return V1_0::ErrorStatus::NONE; // successfully queued
@@ -254,7 +269,18 @@
     // run it
     try
     {
-        armnn::Status status = m_Runtime->EnqueueWorkload(m_NetworkId, inputTensors, outputTensors);
+        armnn::Status status;
+        if (m_AsyncModelExecutionEnabled)
+        {
+            ALOGW("ArmnnPreparedModel::ExecuteGraph m_AsyncModelExecutionEnabled true");
+            status = m_Runtime->Execute(*m_WorkingMemHandle, inputTensors, outputTensors);
+        }
+        else
+        {
+            ALOGW("ArmnnPreparedModel::ExecuteGraph m_AsyncModelExecutionEnabled false");
+            status = m_Runtime->EnqueueWorkload(m_NetworkId, inputTensors, outputTensors);
+        }
+
         if (status != armnn::Status::Success)
         {
             ALOGW("EnqueueWorkload failed");
@@ -340,11 +366,73 @@
     return true;
 }
 
+/// Schedule the graph prepared from the request for execution: the request is handed to
+/// m_Runtime->Schedule() at QosExecPriority::High with a callback object built from its pools, tensors and context.
+template<typename HalVersion>
+template<typename CallbackContext>
+void ArmnnPreparedModel<HalVersion>::ScheduleGraphForExecution(
+        std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
+        std::shared_ptr<armnn::InputTensors>& inputTensors,
+        std::shared_ptr<armnn::OutputTensors>& outputTensors,
+        CallbackContext callbackContext)
+{
+    ALOGV("ArmnnPreparedModel::ScheduleGraphForExecution(...)");
+
+    DumpTensorsIfRequired("Input", *inputTensors);
+
+    auto tpCb = std::make_shared<
+                ArmnnThreadPoolCallback<CallbackContext>>(this,
+                                                          pMemPools,
+                                                          inputTensors,
+                                                          outputTensors,
+                                                          callbackContext);
+
+    m_Runtime->Schedule(m_NetworkId,
+                        *tpCb->m_InputTensors,
+                        *tpCb->m_OutputTensors,
+                        armnn::QosExecPriority::High,
+                        tpCb);
+    ALOGV("ArmnnPreparedModel::ScheduleGraphForExecution end");
+}
+
+/// Receives the result of a scheduled inference: commits the output pools on success, then notifies the client.
+template<typename HalVersion>
+template <typename CallbackContext>
+void ArmnnPreparedModel<HalVersion>::ArmnnThreadPoolCallback<CallbackContext>::Notify(
+        armnn::Status status, armnn::InferenceTimingPair timeTaken)
+{
+    armnn::IgnoreUnused(timeTaken);
+    ALOGV("ArmnnPreparedModel::ArmnnThreadPoolCallback Notify");
+
+    // Only a successful inference has outputs to dump and commit; a failure must never be reported as NONE.
+    V1_0::ErrorStatus errorStatus = V1_0::ErrorStatus::GENERAL_FAILURE;
+    if (status == armnn::Status::Success)
+    {
+        errorStatus = V1_0::ErrorStatus::NONE;
+        m_Model->DumpTensorsIfRequired("Output", *m_OutputTensors);
+        // Every pool is committed, even ones not used as outputs - simpler, and what the CpuExecutor does.
+        for (android::nn::RunTimePoolInfo& pool : *m_MemPools)
+        {
+            #if defined(ARMNN_ANDROID_R) || defined(ARMNN_ANDROID_S) // From Android R, flush() replaces update().
+                pool.flush();
+            #else
+                pool.update();
+            #endif
+        }
+    }
+    m_CallbackContext.callback(errorStatus, "ArmnnPreparedModel::ArmnnThreadPoolCallback Notify");
+}
+
 ///
 /// Class template specializations
 ///
 
 template class ArmnnPreparedModel<hal_1_0::HalPolicy>;
+// Explicitly instantiate the ScheduleGraphForExecution member template (defined above) for CallbackContext_1_0.
+template void ArmnnPreparedModel<hal_1_0::HalPolicy>::ScheduleGraphForExecution<CallbackContext_1_0>(
+        std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>&,
+        std::shared_ptr<armnn::InputTensors>&, std::shared_ptr<armnn::OutputTensors>&,
+        CallbackContext_1_0);
 
 #ifdef ARMNN_ANDROID_NN_V1_1
 template class ArmnnPreparedModel<hal_1_1::HalPolicy>;