IVGCVSW-5781 Add Async Support to Android-NN-Driver

 * Add the -A / --asyncModelExecution option to enable asynchronous
   model execution, and the -T / --armnn-threads option to set the
   number of Arm NN threads (default 1).
 * Pass the async flag and thread count to the runtime through
   INetworkProperties when loading the network.
 * When async execution is enabled, create a working memory handle per
   prepared model and run inferences through IRuntime::Execute, or
   schedule them on the Arm NN thread pool via IRuntime::Schedule with
   an IAsyncExecutionCallback; otherwise keep the existing RequestThread
   path.
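
As a usage sketch only (the service binary name below is illustrative
and depends on the Android/HAL version being built), asynchronous
execution with four Arm NN threads could be requested by starting the
driver with the new options:

    <NN HAL service binary, e.g. android.hardware.neuralnetworks@1.2-service-armnn> -A -T 4
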
Signed-off-by: Finn Williams <Finn.Williams@arm.com>
Change-Id: I1f13d04100fdb119495b9e3054425bf3babc59f1
diff --git a/1.2/ArmnnDriverImpl.cpp b/1.2/ArmnnDriverImpl.cpp
index ccf82d0..3eae526 100644
--- a/1.2/ArmnnDriverImpl.cpp
+++ b/1.2/ArmnnDriverImpl.cpp
@@ -188,9 +188,14 @@
 
     // Load it into the runtime.
     armnn::NetworkId netId = 0;
+    std::string msg;
+    armnn::INetworkProperties networkProperties(options.isAsyncModelExecutionEnabled(),
+                                                armnn::MemorySource::Undefined,
+                                                armnn::MemorySource::Undefined,
+                                                options.getNoOfArmnnThreads());
     try
     {
-        if (runtime->LoadNetwork(netId, move(optNet)) != armnn::Status::Success)
+        if (runtime->LoadNetwork(netId, move(optNet), msg, networkProperties) != armnn::Status::Success)
         {
             return FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, "Network could not be loaded", cb);
         }
@@ -216,7 +221,8 @@
                     runtime.get(),
                     model,
                     options.GetRequestInputsAndOutputsDumpDir(),
-                    options.IsGpuProfilingEnabled()));
+                    options.IsGpuProfilingEnabled(),
+                    options.isAsyncModelExecutionEnabled()));
 
     // Run a single 'dummy' inference of the model. This means that CL kernels will get compiled (and tuned if
     // this is enabled) before the first 'real' inference which removes the overhead of the first inference.
diff --git a/1.3/ArmnnDriverImpl.cpp b/1.3/ArmnnDriverImpl.cpp
index 6d8fbe6..5c5e607 100644
--- a/1.3/ArmnnDriverImpl.cpp
+++ b/1.3/ArmnnDriverImpl.cpp
@@ -199,9 +199,14 @@
 
     // Load it into the runtime.
     armnn::NetworkId netId = 0;
+    std::string msg;
+    armnn::INetworkProperties networkProperties(options.isAsyncModelExecutionEnabled(),
+                                                armnn::MemorySource::Undefined,
+                                                armnn::MemorySource::Undefined,
+                                                options.getNoOfArmnnThreads());
     try
     {
-        if (runtime->LoadNetwork(netId, move(optNet)) != armnn::Status::Success)
+        if (runtime->LoadNetwork(netId, move(optNet), msg, networkProperties) != armnn::Status::Success)
         {
             return FailPrepareModel(V1_3::ErrorStatus::GENERAL_FAILURE, "Network could not be loaded", cb);
         }
@@ -228,7 +233,8 @@
                     model,
                     options.GetRequestInputsAndOutputsDumpDir(),
                     options.IsGpuProfilingEnabled(),
-                    priority));
+                    priority,
+                    options.isAsyncModelExecutionEnabled()));
 
     // Run a single 'dummy' inference of the model. This means that CL kernels will get compiled (and tuned if
     // this is enabled) before the first 'real' inference which removes the overhead of the first inference.
diff --git a/ArmnnDriverImpl.cpp b/ArmnnDriverImpl.cpp
index 3e4aab3..0e6e8b1 100644
--- a/ArmnnDriverImpl.cpp
+++ b/ArmnnDriverImpl.cpp
@@ -163,9 +163,15 @@
 
     // Load it into the runtime.
     armnn::NetworkId netId = 0;
+    std::string msg;
+    armnn::INetworkProperties networkProperties(options.isAsyncModelExecutionEnabled(),
+                                                armnn::MemorySource::Undefined,
+                                                armnn::MemorySource::Undefined,
+                                                options.getNoOfArmnnThreads());
+
     try
     {
-        if (runtime->LoadNetwork(netId, move(optNet)) != armnn::Status::Success)
+        if (runtime->LoadNetwork(netId, move(optNet), msg, networkProperties) != armnn::Status::Success)
         {
             return FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, "Network could not be loaded", cb);
         }
@@ -191,7 +197,8 @@
                     runtime.get(),
                     model,
                     options.GetRequestInputsAndOutputsDumpDir(),
-                    options.IsGpuProfilingEnabled()));
+                    options.IsGpuProfilingEnabled(),
+                    options.isAsyncModelExecutionEnabled()));
 
     // Run a single 'dummy' inference of the model. This means that CL kernels will get compiled (and tuned if
     // this is enabled) before the first 'real' inference which removes the overhead of the first inference.
diff --git a/ArmnnPreparedModel.cpp b/ArmnnPreparedModel.cpp
index 60beac4..978f378 100644
--- a/ArmnnPreparedModel.cpp
+++ b/ArmnnPreparedModel.cpp
@@ -112,16 +112,23 @@
                                                    armnn::IRuntime* runtime,
                                                    const HalModel& model,
                                                    const std::string& requestInputsAndOutputsDumpDir,
-                                                   const bool gpuProfilingEnabled)
+                                                   const bool gpuProfilingEnabled,
+                                                   const bool asyncModelExecutionEnabled)
     : m_NetworkId(networkId)
     , m_Runtime(runtime)
     , m_Model(model)
     , m_RequestCount(0)
     , m_RequestInputsAndOutputsDumpDir(requestInputsAndOutputsDumpDir)
     , m_GpuProfilingEnabled(gpuProfilingEnabled)
+    , m_AsyncModelExecutionEnabled(asyncModelExecutionEnabled)
 {
     // Enable profiling if required.
     m_Runtime->GetProfiler(m_NetworkId)->EnableProfiling(m_GpuProfilingEnabled);
+
+    if (asyncModelExecutionEnabled)
+    {
+        m_WorkingMemHandle = m_Runtime->CreateWorkingMemHandle(networkId);
+    }
 }
 
 template<typename HalVersion>
@@ -225,8 +232,6 @@
         return V1_0::ErrorStatus::GENERAL_FAILURE;
     }
 
-    ALOGV("ArmnnPreparedModel::execute(...) before PostMsg");
-
     auto cb = [callback](V1_0::ErrorStatus errorStatus, std::string callingFunction)
     {
         NotifyCallbackAndCheck(callback, errorStatus, callingFunction);
@@ -234,7 +239,17 @@
 
     CallbackContext_1_0 armnnCb;
     armnnCb.callback = cb;
+
+    if (m_AsyncModelExecutionEnabled)
+    {
+        ALOGV("ArmnnPreparedModel::execute(...) before ScheduleGraphForExecution");
+        ScheduleGraphForExecution(pMemPools, pInputTensors, pOutputTensors, armnnCb);
+        ALOGV("ArmnnPreparedModel::execute(...) after ScheduleGraphForExecution");
+        return V1_0::ErrorStatus::NONE;
+    }
+
     // post the request for asynchronous execution
+    ALOGV("ArmnnPreparedModel::execute(...) before PostMsg");
     m_RequestThread.PostMsg(this, pMemPools, pInputTensors, pOutputTensors, armnnCb);
     ALOGV("ArmnnPreparedModel::execute(...) after PostMsg");
     return V1_0::ErrorStatus::NONE; // successfully queued
@@ -254,7 +269,18 @@
     // run it
     try
     {
-        armnn::Status status = m_Runtime->EnqueueWorkload(m_NetworkId, inputTensors, outputTensors);
+        armnn::Status status;
+        if (m_AsyncModelExecutionEnabled)
+        {
+            ALOGV("ArmnnPreparedModel::ExecuteGraph m_AsyncModelExecutionEnabled true");
+            status = m_Runtime->Execute(*m_WorkingMemHandle, inputTensors, outputTensors);
+        }
+        else
+        {
+            ALOGV("ArmnnPreparedModel::ExecuteGraph m_AsyncModelExecutionEnabled false");
+            status = m_Runtime->EnqueueWorkload(m_NetworkId, inputTensors, outputTensors);
+        }
+
         if (status != armnn::Status::Success)
         {
             ALOGW("EnqueueWorkload failed");
@@ -340,11 +366,73 @@
     return true;
 }
 
+/// Schedule the graph prepared from the request for execution
+template<typename HalVersion>
+template<typename CallbackContext>
+void ArmnnPreparedModel<HalVersion>::ScheduleGraphForExecution(
+        std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
+        std::shared_ptr<armnn::InputTensors>& inputTensors,
+        std::shared_ptr<armnn::OutputTensors>& outputTensors,
+        CallbackContext callbackContext)
+{
+    ALOGV("ArmnnPreparedModel::ScheduleGraphForExecution(...)");
+
+    DumpTensorsIfRequired("Input", *inputTensors);
+
+
+    auto tpCb = std::make_shared<
+                ArmnnThreadPoolCallback<CallbackContext_1_0>>(this,
+                                                              pMemPools,
+                                                              inputTensors,
+                                                              outputTensors,
+                                                              callbackContext);
+
+    m_Runtime->Schedule(m_NetworkId,
+                        *tpCb->m_InputTensors,
+                        *tpCb->m_OutputTensors,
+                        armnn::QosExecPriority::High,
+                        tpCb);
+    ALOGV("ArmnnPreparedModel::ScheduleGraphForExecution end");
+}
+
+template<typename HalVersion>
+template <typename CallbackContext>
+void ArmnnPreparedModel<HalVersion>::ArmnnThreadPoolCallback<CallbackContext>::Notify(
+        armnn::Status status, armnn::InferenceTimingPair timeTaken)
+{
+    armnn::IgnoreUnused(status, timeTaken);
+    ALOGV("ArmnnPreparedModel::ArmnnThreadPoolCallback Notify");
+
+    m_Model->DumpTensorsIfRequired("Output", *m_OutputTensors);
+
+    // Commit output buffers.
+    // Note that we update *all* pools, even if they aren't actually used as outputs -
+    // this is simpler and is what the CpuExecutor does.
+    for (android::nn::RunTimePoolInfo& pool : *m_MemPools)
+    {
+        // Type android::nn::RunTimePoolInfo has changed between Android P & Q and Android R, where
+        // update() has been removed and flush() added.
+        #if defined(ARMNN_ANDROID_R) || defined(ARMNN_ANDROID_S) // Use the new Android implementation.
+            pool.flush();
+        #else
+            pool.update();
+        #endif
+    }
+
+    m_CallbackContext.callback(V1_0::ErrorStatus::NONE, "ArmnnPreparedModel::ArmnnThreadPoolCallback Notify");
+    return;
+}
+
 ///
 /// Class template specializations
 ///
 
 template class ArmnnPreparedModel<hal_1_0::HalPolicy>;
+template void ArmnnPreparedModel<hal_1_0::HalPolicy>::ScheduleGraphForExecution<CallbackContext_1_0>(
+        std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
+        std::shared_ptr<armnn::InputTensors>& inputTensors,
+        std::shared_ptr<armnn::OutputTensors>& outputTensors,
+        CallbackContext_1_0 callbackContext);
 
 #ifdef ARMNN_ANDROID_NN_V1_1
 template class ArmnnPreparedModel<hal_1_1::HalPolicy>;
diff --git a/ArmnnPreparedModel.hpp b/ArmnnPreparedModel.hpp
index 89f6226..d1c830d 100644
--- a/ArmnnPreparedModel.hpp
+++ b/ArmnnPreparedModel.hpp
@@ -38,7 +38,8 @@
                        armnn::IRuntime* runtime,
                        const HalModel& model,
                        const std::string& requestInputsAndOutputsDumpDir,
-                       const bool gpuProfilingEnabled);
+                       const bool gpuProfilingEnabled,
+                       const bool asyncModelExecutionEnabled = false);
 
     virtual ~ArmnnPreparedModel();
 
@@ -56,9 +57,65 @@
     bool ExecuteWithDummyInputs();
 
 private:
+
+    template<typename CallbackContext>
+    class ArmnnThreadPoolCallback : public armnn::IAsyncExecutionCallback
+    {
+    public:
+        ArmnnThreadPoolCallback(ArmnnPreparedModel<HalVersion>* model,
+                                std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
+                                std::shared_ptr<armnn::InputTensors>& inputTensors,
+                                std::shared_ptr<armnn::OutputTensors>& outputTensors,
+                                CallbackContext callbackContext) :
+                m_Model(model),
+                m_MemPools(pMemPools),
+                m_InputTensors(inputTensors),
+                m_OutputTensors(outputTensors),
+                m_CallbackContext(callbackContext)
+        {}
+
+        void Notify(armnn::Status status, armnn::InferenceTimingPair timeTaken) override;
+
+        // Retrieve the Arm NN Status from the AsyncExecutionCallback that has been notified
+        virtual armnn::Status GetStatus() const override
+        {
+            return armnn::Status::Success;
+        }
+
+        // Block the calling thread until the AsyncExecutionCallback object allows it to proceed
+        virtual void Wait() const override
+        {}
+
+        // Retrieve the start time before executing the inference
+        virtual armnn::HighResolutionClock GetStartTime() const override
+        {
+            return std::chrono::high_resolution_clock::now();
+        }
+
+        // Retrieve the time after executing the inference
+        virtual armnn::HighResolutionClock GetEndTime() const override
+        {
+            return std::chrono::high_resolution_clock::now();
+        }
+
+        ArmnnPreparedModel<HalVersion>* m_Model;
+        std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>> m_MemPools;
+        std::shared_ptr<armnn::InputTensors> m_InputTensors;
+        std::shared_ptr<armnn::OutputTensors> m_OutputTensors;
+        CallbackContext m_CallbackContext;
+    };
+
     template <typename TensorBindingCollection>
     void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings);
 
+    /// Schedule the graph prepared from the request for execution
+    template<typename CallbackContext>
+    void ScheduleGraphForExecution(
+            std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
+            std::shared_ptr<armnn::InputTensors>& inputTensors,
+            std::shared_ptr<armnn::OutputTensors>& outputTensors,
+            CallbackContext callbackContext);
+
     armnn::NetworkId                                                        m_NetworkId;
     armnn::IRuntime*                                                        m_Runtime;
     HalModel                                                                m_Model;
@@ -68,6 +125,9 @@
     uint32_t                                                                m_RequestCount;
     const std::string&                                                      m_RequestInputsAndOutputsDumpDir;
     const bool                                                              m_GpuProfilingEnabled;
+
+    std::unique_ptr<armnn::IWorkingMemHandle> m_WorkingMemHandle;
+    const bool m_AsyncModelExecutionEnabled;
 };
 
 }
diff --git a/ArmnnPreparedModel_1_2.cpp b/ArmnnPreparedModel_1_2.cpp
index a2148c2..c129fd6 100644
--- a/ArmnnPreparedModel_1_2.cpp
+++ b/ArmnnPreparedModel_1_2.cpp
@@ -6,6 +6,7 @@
 #define LOG_TAG "ArmnnDriver"
 
 #include "ArmnnPreparedModel_1_2.hpp"
+
 #include "Utils.hpp"
 
 #include <log/log.h>
@@ -146,16 +147,23 @@
                                                            armnn::IRuntime* runtime,
                                                            const V1_2::Model& model,
                                                            const std::string& requestInputsAndOutputsDumpDir,
-                                                           const bool gpuProfilingEnabled)
+                                                           const bool gpuProfilingEnabled,
+                                                           const bool asyncModelExecutionEnabled)
     : m_NetworkId(networkId)
     , m_Runtime(runtime)
     , m_Model(model)
     , m_RequestCount(0)
     , m_RequestInputsAndOutputsDumpDir(requestInputsAndOutputsDumpDir)
     , m_GpuProfilingEnabled(gpuProfilingEnabled)
+    , m_AsyncModelExecutionEnabled(asyncModelExecutionEnabled)
 {
     // Enable profiling if required.
     m_Runtime->GetProfiler(m_NetworkId)->EnableProfiling(m_GpuProfilingEnabled);
+
+    if (asyncModelExecutionEnabled)
+    {
+        m_WorkingMemHandle = m_Runtime->CreateWorkingMemHandle(networkId);
+    }
 }
 
 template<typename HalVersion>
@@ -440,7 +448,17 @@
             deviceStart = Now();
         }
 
-        armnn::Status status = m_Runtime->EnqueueWorkload(m_NetworkId, inputTensors, outputTensors);
+        armnn::Status status;
+        if (m_AsyncModelExecutionEnabled)
+        {
+            ALOGV("ArmnnPreparedModel_1_2::ExecuteGraph m_AsyncModelExecutionEnabled true");
+            status = m_Runtime->Execute(*m_WorkingMemHandle, inputTensors, outputTensors);
+        }
+        else
+        {
+            ALOGV("ArmnnPreparedModel_1_2::ExecuteGraph m_AsyncModelExecutionEnabled false");
+            status = m_Runtime->EnqueueWorkload(m_NetworkId, inputTensors, outputTensors);
+        }
 
         if (cb.ctx.measureTimings == V1_2::MeasureTiming::YES)
         {
@@ -567,12 +585,21 @@
         {}
     }
 
-    ALOGV("ArmnnPreparedModel_1_2::execute(...) before PostMsg");
 
     // post the request for asynchronous execution
     CallbackContext_1_2 cb;
     cb.callback = callback;
     cb.ctx = ctx;
+
+    if (m_AsyncModelExecutionEnabled)
+    {
+        ALOGV("ArmnnPreparedModel_1_2::execute(...) before ScheduleGraphForExecution");
+        ScheduleGraphForExecution(memPools, inputTensors, outputTensors, cb);
+        ALOGV("ArmnnPreparedModel_1_2::execute(...) after ScheduleGraphForExecution");
+        return V1_0::ErrorStatus::NONE;
+    }
+
+    ALOGV("ArmnnPreparedModel_1_2::execute(...) before PostMsg");
     m_RequestThread.PostMsg(this, memPools, inputTensors, outputTensors, cb);
     ALOGV("ArmnnPreparedModel_1_2::execute(...) after PostMsg");
     return V1_0::ErrorStatus::NONE;
@@ -602,6 +629,84 @@
     return Void();
 }
 
+/// Schedule the graph prepared from the request for execution
+template<typename HalVersion>
+template<typename CallbackContext>
+void ArmnnPreparedModel_1_2<HalVersion>::ScheduleGraphForExecution(
+        std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
+        std::shared_ptr<armnn::InputTensors>& inputTensors,
+        std::shared_ptr<armnn::OutputTensors>& outputTensors,
+        CallbackContext callbackContext)
+{
+    ALOGV("ArmnnPreparedModel_1_2::ScheduleGraphForExecution(...)");
+
+    DumpTensorsIfRequired("Input", *inputTensors);
+
+    unsigned int outputTensorSize = outputTensors.get()->size();
+    std::vector<V1_2::OutputShape> outputShapes(outputTensorSize);
+    for (unsigned int i = 0; i < outputTensorSize; i++)
+    {
+        std::pair<int, armnn::Tensor> outputTensorPair = outputTensors.get()->at(i);
+        const armnn::Tensor outputTensor = outputTensorPair.second;
+        const armnn::TensorInfo outputTensorInfo = outputTensor.GetInfo();
+
+        outputShapes[i] = ComputeShape(outputTensorInfo);
+    }
+
+    auto tpCb = std::make_shared<
+        ArmnnThreadPoolCallback_1_2<CallbackContext_1_2>>(this,
+                                                          pMemPools,
+                                                          outputShapes,
+                                                          inputTensors,
+                                                          outputTensors,
+                                                          callbackContext);
+
+    m_Runtime->Schedule(m_NetworkId,
+                        *tpCb->m_InputTensors,
+                        *tpCb->m_OutputTensors,
+                        armnn::QosExecPriority::High,
+                        tpCb);
+    ALOGV("ArmnnPreparedModel_1_2::ScheduleGraphForExecution end");
+}
+
+template<typename HalVersion>
+template <typename CallbackContext>
+void ArmnnPreparedModel_1_2<HalVersion>::ArmnnThreadPoolCallback_1_2<CallbackContext>::Notify(
+        armnn::Status status, armnn::InferenceTimingPair timeTaken)
+{
+    ALOGV("ArmnnPreparedModel_1_2::ArmnnThreadPoolCallback_1_2 Notify");
+
+    TimePoint driverEnd;
+
+    CommitPools(*m_MemPools);
+
+    m_Model->DumpTensorsIfRequired("Output", *m_OutputTensors);
+
+    if (status != armnn::Status::Success)
+    {
+        ALOGW("ArmnnThreadPoolCallback_1_2::Notify execution failed");
+        m_CallbackContext.callback(
+                V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::ExecuteGraph");
+        return;
+    }
+
+    if (m_CallbackContext.ctx.measureTimings == V1_2::MeasureTiming::YES)
+    {
+        driverEnd = std::chrono::steady_clock::now();
+        V1_2::Timing timing;
+        timing.timeOnDevice = MicrosecondsDuration(timeTaken.second, timeTaken.first);
+        timing.timeInDriver = MicrosecondsDuration(driverEnd, m_CallbackContext.ctx.driverStart);
+        ALOGV("ArmnnPreparedModel_1_2::execute timing - Device = %lu Driver = %lu", timing.timeOnDevice,
+              timing.timeInDriver);
+        m_CallbackContext.callback(
+                V1_0::ErrorStatus::NONE, m_OutputShapes, timing, "ArmnnPreparedModel_1_2::ExecuteGraph");
+    } else {
+        m_CallbackContext.callback(
+                V1_0::ErrorStatus::NONE, m_OutputShapes, g_NoTiming, "ArmnnPreparedModel_1_2::ExecuteGraph");
+    }
+    return;
+}
+
 #if defined(ARMNN_ANDROID_NN_V1_2) || defined(ARMNN_ANDROID_NN_V1_3)
 template class ArmnnPreparedModel_1_2<hal_1_2::HalPolicy>;
 template bool ArmnnPreparedModel_1_2<hal_1_2::HalPolicy>::ExecuteGraph<CallbackContext_1_2>(
@@ -609,6 +714,12 @@
         armnn::InputTensors& pInputTensors,
         armnn::OutputTensors& pOutputTensors,
         CallbackContext_1_2 cb);
+
+template void ArmnnPreparedModel_1_2<hal_1_2::HalPolicy>::ScheduleGraphForExecution<CallbackContext_1_2>(
+                std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
+                std::shared_ptr<armnn::InputTensors>& inputTensors,
+                std::shared_ptr<armnn::OutputTensors>& outputTensors,
+                CallbackContext_1_2 callbackContext);
 #endif
 
 } // namespace armnn_driver
diff --git a/ArmnnPreparedModel_1_2.hpp b/ArmnnPreparedModel_1_2.hpp
index 13d7494..6c630c5 100644
--- a/ArmnnPreparedModel_1_2.hpp
+++ b/ArmnnPreparedModel_1_2.hpp
@@ -44,7 +44,8 @@
                            armnn::IRuntime* runtime,
                            const HalModel& model,
                            const std::string& requestInputsAndOutputsDumpDir,
-                           const bool gpuProfilingEnabled);
+                           const bool gpuProfilingEnabled,
+                           const bool asyncModelExecutionEnabled = false);
 
     virtual ~ArmnnPreparedModel_1_2();
 
@@ -76,6 +77,57 @@
     bool ExecuteWithDummyInputs();
 
 private:
+
+    template<typename CallbackContext>
+    class ArmnnThreadPoolCallback_1_2 : public armnn::IAsyncExecutionCallback
+    {
+    public:
+        ArmnnThreadPoolCallback_1_2(ArmnnPreparedModel_1_2<HalVersion>* model,
+                                    std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
+                                    std::vector<V1_2::OutputShape> outputShapes,
+                                    std::shared_ptr<armnn::InputTensors>& inputTensors,
+                                    std::shared_ptr<armnn::OutputTensors>& outputTensors,
+                                    CallbackContext callbackContext) :
+                m_Model(model),
+                m_MemPools(pMemPools),
+                m_OutputShapes(outputShapes),
+                m_InputTensors(inputTensors),
+                m_OutputTensors(outputTensors),
+                m_CallbackContext(callbackContext)
+        {}
+
+        void Notify(armnn::Status status, armnn::InferenceTimingPair timeTaken) override;
+
+        // Retrieve the Arm NN Status from the AsyncExecutionCallback that has been notified
+        virtual armnn::Status GetStatus() const override
+        {
+            return armnn::Status::Success;
+        }
+
+        // Block the calling thread until the AsyncExecutionCallback object allows it to proceed
+        virtual void Wait() const override
+        {}
+
+        // Retrieve the start time before executing the inference
+        virtual armnn::HighResolutionClock GetStartTime() const override
+        {
+            return std::chrono::high_resolution_clock::now();
+        }
+
+        // Retrieve the time after executing the inference
+        virtual armnn::HighResolutionClock GetEndTime() const override
+        {
+            return std::chrono::high_resolution_clock::now();
+        }
+
+        ArmnnPreparedModel_1_2<HalVersion>* m_Model;
+        std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>> m_MemPools;
+        std::vector<V1_2::OutputShape> m_OutputShapes;
+        std::shared_ptr<armnn::InputTensors> m_InputTensors;
+        std::shared_ptr<armnn::OutputTensors> m_OutputTensors;
+        CallbackContext m_CallbackContext;
+    };
+
     Return<V1_0::ErrorStatus> Execute(const V1_0::Request& request,
                                       V1_2::MeasureTiming measureTiming,
                                       CallbackAsync_1_2 callback);
@@ -101,6 +153,14 @@
     template <typename TensorBindingCollection>
     void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings);
 
+    /// Schedule the graph prepared from the request for execution
+    template<typename CallbackContext>
+    void ScheduleGraphForExecution(
+            std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
+            std::shared_ptr<armnn::InputTensors>& inputTensors,
+            std::shared_ptr<armnn::OutputTensors>& outputTensors,
+            CallbackContext callbackContext);
+
     armnn::NetworkId                                                            m_NetworkId;
     armnn::IRuntime*                                                            m_Runtime;
     V1_2::Model                                                                 m_Model;
@@ -112,6 +172,9 @@
     uint32_t                                                                    m_RequestCount;
     const std::string&                                                          m_RequestInputsAndOutputsDumpDir;
     const bool                                                                  m_GpuProfilingEnabled;
+
+    std::unique_ptr<armnn::IWorkingMemHandle> m_WorkingMemHandle;
+    const bool m_AsyncModelExecutionEnabled;
 };
 
 }
diff --git a/ArmnnPreparedModel_1_3.cpp b/ArmnnPreparedModel_1_3.cpp
index 3d93b99..5a37032 100644
--- a/ArmnnPreparedModel_1_3.cpp
+++ b/ArmnnPreparedModel_1_3.cpp
@@ -168,7 +168,8 @@
                                                            const V1_3::Model& model,
                                                            const std::string& requestInputsAndOutputsDumpDir,
                                                            const bool gpuProfilingEnabled,
-                                                           V1_3::Priority priority)
+                                                           V1_3::Priority priority,
+                                                           const bool asyncModelExecutionEnabled)
     : m_NetworkId(networkId)
     , m_Runtime(runtime)
     , m_Model(model)
@@ -176,9 +177,15 @@
     , m_RequestInputsAndOutputsDumpDir(requestInputsAndOutputsDumpDir)
     , m_GpuProfilingEnabled(gpuProfilingEnabled)
     , m_ModelPriority(priority)
+    , m_AsyncModelExecutionEnabled(asyncModelExecutionEnabled)
 {
     // Enable profiling if required.
     m_Runtime->GetProfiler(m_NetworkId)->EnableProfiling(m_GpuProfilingEnabled);
+
+    if (asyncModelExecutionEnabled)
+    {
+        m_WorkingMemHandle = m_Runtime->CreateWorkingMemHandle(networkId);
+    }
 }
 
 template<typename HalVersion>
@@ -726,8 +733,17 @@
         {
             cb.ctx.deviceStart = Now();
         }
-
-        armnn::Status status = m_Runtime->EnqueueWorkload(m_NetworkId, inputTensors, outputTensors);
+        armnn::Status status;
+        if (m_AsyncModelExecutionEnabled)
+        {
+            ALOGV("ArmnnPreparedModel_1_3::ExecuteGraph m_AsyncModelExecutionEnabled true");
+            status = m_Runtime->Execute(*m_WorkingMemHandle, inputTensors, outputTensors);
+        }
+        else
+        {
+            ALOGV("ArmnnPreparedModel_1_3::ExecuteGraph m_AsyncModelExecutionEnabled false");
+            status = m_Runtime->EnqueueWorkload(m_NetworkId, inputTensors, outputTensors);
+        }
 
         if (cb.ctx.measureTimings == V1_2::MeasureTiming::YES)
         {
@@ -735,7 +751,7 @@
         }
         if (status != armnn::Status::Success)
         {
-            ALOGW("EnqueueWorkload failed");
+            ALOGW("ArmnnPreparedModel_1_3::ExecuteGraph EnqueueWorkload failed");
             cb.callback(V1_3::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_3::ExecuteGraph");
             return V1_3::ErrorStatus::GENERAL_FAILURE;
         }
@@ -773,6 +789,47 @@
     return V1_3::ErrorStatus::NONE;
 }
 
+/// Schedule the graph prepared from the request for execution
+template<typename HalVersion>
+template<typename CallbackContext>
+void ArmnnPreparedModel_1_3<HalVersion>::ScheduleGraphForExecution(
+        std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
+        std::shared_ptr<armnn::InputTensors>& inputTensors,
+        std::shared_ptr<armnn::OutputTensors>& outputTensors,
+        CallbackContext callbackContext,
+        armnn::QosExecPriority priority)
+{
+    ALOGV("ArmnnPreparedModel_1_3::ScheduleGraphForExecution(...)");
+
+    DumpTensorsIfRequired("Input", *inputTensors);
+
+    unsigned int outputTensorSize = outputTensors.get()->size();
+    std::vector<V1_2::OutputShape> outputShapes(outputTensorSize);
+    for (unsigned int i = 0; i < outputTensorSize; i++)
+    {
+        std::pair<int, armnn::Tensor> outputTensorPair = outputTensors.get()->at(i);
+        const armnn::Tensor outputTensor = outputTensorPair.second;
+        const armnn::TensorInfo outputTensorInfo = outputTensor.GetInfo();
+
+        outputShapes[i] = ComputeShape(outputTensorInfo);
+    }
+
+    auto tpCb = std::make_shared<
+        ArmnnThreadPoolCallback_1_3<CallbackContext_1_3>>(this,
+                                                          pMemPools,
+                                                          outputShapes,
+                                                          inputTensors,
+                                                          outputTensors,
+                                                          callbackContext);
+
+    m_Runtime->Schedule(m_NetworkId,
+                        *tpCb->m_InputTensors,
+                        *tpCb->m_OutputTensors,
+                        priority,
+                        tpCb);
+    ALOGV("ArmnnPreparedModel_1_3::ScheduleGraphForExecution end");
+}
+
 template<typename HalVersion>
 bool ArmnnPreparedModel_1_3<HalVersion>::ExecuteWithDummyInputs()
 {
@@ -862,13 +919,37 @@
         default:
         {}
     }
-
-    ALOGV("ArmnnPreparedModel_1_3::execute(...) before PostMsg");
-
-    // post the request for asynchronous execution
     CallbackContext_1_3 cb;
     cb.callback = callback;
     cb.ctx = ctx;
+
+    if (m_AsyncModelExecutionEnabled)
+    {
+        armnn::QosExecPriority priority;
+
+        switch (GetModelPriority()) {
+            case V1_3::Priority::LOW:
+                priority = armnn::QosExecPriority::Low;
+                break;
+            case V1_3::Priority::MEDIUM:
+                priority = armnn::QosExecPriority::Medium;
+                break;
+            case V1_3::Priority::HIGH:
+                priority = armnn::QosExecPriority::High;
+                break;
+            default:
+                priority = armnn::QosExecPriority::Medium;
+
+        }
+
+        ALOGV("ArmnnPreparedModel_1_3::execute(...) before ScheduleGraphForExecution");
+        ScheduleGraphForExecution(memPools, inputTensors, outputTensors, cb, priority);
+        ALOGV("ArmnnPreparedModel_1_3::execute(...) after ScheduleGraphForExecution");
+        return V1_3::ErrorStatus::NONE;
+    }
+
+    ALOGV("ArmnnPreparedModel_1_3::execute(...) before PostMsg");
+    // post the request for asynchronous execution
     m_RequestThread.PostMsg(this, memPools, inputTensors, outputTensors, cb);
     ALOGV("ArmnnPreparedModel_1_3::execute(...) after PostMsg");
     return V1_3::ErrorStatus::NONE;
@@ -880,6 +961,46 @@
     return m_ModelPriority;
 }
 
+template<typename HalVersion>
+template <typename CallbackContext>
+void ArmnnPreparedModel_1_3<HalVersion>::ArmnnThreadPoolCallback_1_3<CallbackContext>::Notify(
+        armnn::Status status, armnn::InferenceTimingPair timeTaken)
+{
+    ALOGV("ArmnnPreparedModel_1_3::ArmnnThreadPoolCallback_1_3<CallbackContext>::Notify");
+    CommitPools(*m_MemPools);
+
+    m_Model->DumpTensorsIfRequired("Output", *m_OutputTensors);
+
+    if (status != armnn::Status::Success)
+    {
+        ALOGW("ArmnnThreadPoolCallback_1_3::Notify execution failed");
+        m_CallbackContext.callback(V1_3::ErrorStatus::GENERAL_FAILURE,
+                                   {},
+                                   g_NoTiming,
+                                   "ArmnnPreparedModel_1_3::ArmnnThreadPoolCallback_1_3");
+        return;
+    }
+
+    if (m_CallbackContext.ctx.measureTimings == V1_2::MeasureTiming::YES)
+    {
+        m_CallbackContext.ctx.deviceStart = timeTaken.first;
+        m_CallbackContext.ctx.deviceEnd = timeTaken.second;
+        m_CallbackContext.ctx.driverEnd = std::chrono::steady_clock::now();
+        V1_2::Timing timing;
+        timing.timeOnDevice = MicrosecondsDuration(m_CallbackContext.ctx.deviceEnd, m_CallbackContext.ctx.deviceStart);
+        timing.timeInDriver = MicrosecondsDuration(m_CallbackContext.ctx.driverEnd, m_CallbackContext.ctx.driverStart);
+        ALOGV("ArmnnPreparedModel_1_3::execute timing - Device = %lu Driver = %lu", timing.timeOnDevice,
+              timing.timeInDriver);
+        m_CallbackContext.callback(
+                V1_3::ErrorStatus::NONE, m_OutputShapes, timing, "ArmnnPreparedModel_1_3::ExecuteGraph");
+    } else
+    {
+        m_CallbackContext.callback(
+                V1_3::ErrorStatus::NONE, m_OutputShapes, g_NoTiming, "ArmnnPreparedModel_1_3::ExecuteGraph");
+    }
+    return;
+}
+
 #ifdef ARMNN_ANDROID_NN_V1_3
 template class ArmnnPreparedModel_1_3<hal_1_3::HalPolicy>;
 template Return <V1_3::ErrorStatus> ArmnnPreparedModel_1_3<hal_1_3::HalPolicy>::ExecuteGraph<CallbackContext_1_3>(
@@ -887,6 +1008,13 @@
         armnn::InputTensors& pInputTensors,
         armnn::OutputTensors& pOutputTensors,
         CallbackContext_1_3 cb);
+
+template void ArmnnPreparedModel_1_3<hal_1_3::HalPolicy>::ScheduleGraphForExecution<CallbackContext_1_3>(
+                std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
+                std::shared_ptr<armnn::InputTensors>& inputTensors,
+                std::shared_ptr<armnn::OutputTensors>& outputTensors,
+                CallbackContext_1_3 callbackContext,
+                armnn::QosExecPriority priority);
 #endif
 
 } // namespace armnn_driver
diff --git a/ArmnnPreparedModel_1_3.hpp b/ArmnnPreparedModel_1_3.hpp
index c6cdcdc..11299cc 100644
--- a/ArmnnPreparedModel_1_3.hpp
+++ b/ArmnnPreparedModel_1_3.hpp
@@ -51,7 +51,8 @@
                            const HalModel& model,
                            const std::string& requestInputsAndOutputsDumpDir,
                            const bool gpuProfilingEnabled,
-                           V1_3::Priority priority = V1_3::Priority::MEDIUM);
+                           V1_3::Priority priority = V1_3::Priority::MEDIUM,
+                           const bool asyncModelExecutionEnabled = false);
 
     virtual ~ArmnnPreparedModel_1_3();
 
@@ -109,6 +110,57 @@
     V1_3::Priority GetModelPriority();
 
 private:
+
+    template<typename CallbackContext>
+    class ArmnnThreadPoolCallback_1_3 : public armnn::IAsyncExecutionCallback
+    {
+    public:
+        ArmnnThreadPoolCallback_1_3(ArmnnPreparedModel_1_3<HalVersion>* model,
+                                    std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
+                                    std::vector<V1_2::OutputShape> outputShapes,
+                                    std::shared_ptr<armnn::InputTensors>& inputTensors,
+                                    std::shared_ptr<armnn::OutputTensors>& outputTensors,
+                                    CallbackContext callbackContext) :
+                m_Model(model),
+                m_MemPools(pMemPools),
+                m_OutputShapes(outputShapes),
+                m_InputTensors(inputTensors),
+                m_OutputTensors(outputTensors),
+                m_CallbackContext(callbackContext)
+        {}
+
+        void Notify(armnn::Status status, armnn::InferenceTimingPair timeTaken) override;
+
+        // Retrieve the Arm NN Status from the AsyncExecutionCallback that has been notified
+        virtual armnn::Status GetStatus() const override
+        {
+            return armnn::Status::Success;
+        }
+
+        // Block the calling thread until the AsyncExecutionCallback object allows it to proceed
+        virtual void Wait() const override
+        {}
+
+        // Retrieve the start time before executing the inference
+        virtual armnn::HighResolutionClock GetStartTime() const override
+        {
+            return std::chrono::high_resolution_clock::now();
+        }
+
+        // Retrieve the time after executing the inference
+        virtual armnn::HighResolutionClock GetEndTime() const override
+        {
+            return std::chrono::high_resolution_clock::now();
+        }
+
+        ArmnnPreparedModel_1_3<HalVersion>* m_Model;
+        std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>> m_MemPools;
+        std::vector<V1_2::OutputShape> m_OutputShapes;
+        std::shared_ptr<armnn::InputTensors> m_InputTensors;
+        std::shared_ptr<armnn::OutputTensors> m_OutputTensors;
+        CallbackContext m_CallbackContext;
+    };
+
     Return <V1_3::ErrorStatus> Execute(const V1_3::Request& request,
                                        V1_2::MeasureTiming measureTiming,
                                        CallbackAsync_1_3 callback);
@@ -133,6 +185,15 @@
     template <typename TensorBindingCollection>
     void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings);
 
+    /// Schedule the graph prepared from the request for execution
+    template<typename CallbackContext>
+    void ScheduleGraphForExecution(
+            std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
+            std::shared_ptr<armnn::InputTensors>& inputTensors,
+            std::shared_ptr<armnn::OutputTensors>& outputTensors,
+            CallbackContext callbackContext,
+            armnn::QosExecPriority priority);
+
     armnn::NetworkId                                                            m_NetworkId;
     armnn::IRuntime*                                                            m_Runtime;
     V1_3::Model                                                                 m_Model;
@@ -143,6 +204,9 @@
     const std::string&                                                          m_RequestInputsAndOutputsDumpDir;
     const bool                                                                  m_GpuProfilingEnabled;
     V1_3::Priority                                                              m_ModelPriority;
+
+    std::unique_ptr<armnn::IWorkingMemHandle> m_WorkingMemHandle;
+    const bool m_AsyncModelExecutionEnabled;
 };
 
 }
diff --git a/DriverOptions.cpp b/DriverOptions.cpp
index 42f7ea9..5b67aa3 100644
--- a/DriverOptions.cpp
+++ b/DriverOptions.cpp
@@ -39,6 +39,8 @@
     , m_ShouldExit(false)
     , m_SaveCachedNetwork(false)
     , m_NumberOfThreads(0)
+    , m_EnableAsyncModelExecution(false)
+    , m_ArmnnNumberOfThreads(1)
 {
 }
 
@@ -53,6 +55,8 @@
     , m_ShouldExit(false)
     , m_SaveCachedNetwork(false)
     , m_NumberOfThreads(0)
+    , m_EnableAsyncModelExecution(false)
+    , m_ArmnnNumberOfThreads(1)
 {
 }
 
@@ -66,6 +70,8 @@
     , m_ShouldExit(false)
     , m_SaveCachedNetwork(false)
     , m_NumberOfThreads(0)
+    , m_EnableAsyncModelExecution(false)
+    , m_ArmnnNumberOfThreads(1)
 {
     std::string unsupportedOperationsAsString;
     std::string clTunedParametersModeAsString;
@@ -154,7 +160,16 @@
          cxxopts::value<bool>(m_VerboseLogging)->default_value("false"))
 
         ("V,version", "Show version information",
-         cxxopts::value<bool>(showVersion)->default_value("false"));
+         cxxopts::value<bool>(showVersion)->default_value("false"))
+
+        ("A,asyncModelExecution", "Enable asynchronous model execution",
+         cxxopts::value<bool>(m_EnableAsyncModelExecution)->default_value("false"))
+
+        ("T,armnn-threads",
+         "Assign the number of threads used by Arm NN. "
+         "Input value must be at least 1. "
+         "Default is set to 1.",
+         cxxopts::value<unsigned int>(m_ArmnnNumberOfThreads)->default_value("1"));
     }
     catch (const std::exception& e)
     {
diff --git a/DriverOptions.hpp b/DriverOptions.hpp
index 8b3f574..e1d25c4 100644
--- a/DriverOptions.hpp
+++ b/DriverOptions.hpp
@@ -40,6 +40,8 @@
     const std::string& GetCachedNetworkFilePath() const { return m_CachedNetworkFilePath; }
     bool SaveCachedNetwork() const { return m_SaveCachedNetwork; }
     unsigned int GetNumberOfThreads() const { return m_NumberOfThreads; }
+    bool isAsyncModelExecutionEnabled() const { return m_EnableAsyncModelExecution; }
+    unsigned int getNoOfArmnnThreads() const { return m_ArmnnNumberOfThreads; }
 
 private:
     std::vector<armnn::BackendId> m_Backends;
@@ -59,6 +61,8 @@
     std::string m_CachedNetworkFilePath;
     bool m_SaveCachedNetwork;
     unsigned int m_NumberOfThreads;
+    bool m_EnableAsyncModelExecution;
+    unsigned int m_ArmnnNumberOfThreads;
 };
 
 } // namespace armnn_driver