IVGCVSW-5781 Add Async Support to Android-NN-Driver

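* Add an asyncModelExecutionEnabled option to ArmnnPreparedModel_1_2
* Add ArmnnThreadPoolCallback_1_2, an armnn::IAsyncExecutionCallback that
  carries the request's memory pools, tensors and output shapes through to
  Notify()
* Add ScheduleGraphForExecution to schedule the prepared graph for
  asynchronous execution
* Hold a working memory handle for the asynchronous execution path
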
Signed-off-by: Finn Williams <Finn.Williams@arm.com>
Change-Id: I1f13d04100fdb119495b9e3054425bf3babc59f1
diff --git a/ArmnnPreparedModel_1_2.hpp b/ArmnnPreparedModel_1_2.hpp
index 13d7494..6c630c5 100644
--- a/ArmnnPreparedModel_1_2.hpp
+++ b/ArmnnPreparedModel_1_2.hpp
@@ -44,7 +44,8 @@
                            armnn::IRuntime* runtime,
                            const HalModel& model,
                            const std::string& requestInputsAndOutputsDumpDir,
-                           const bool gpuProfilingEnabled);
+                           const bool gpuProfilingEnabled,
+                           const bool asyncModelExecutionEnabled = false);
 
     virtual ~ArmnnPreparedModel_1_2();
 
@@ -76,6 +77,61 @@
     bool ExecuteWithDummyInputs();
 
 private:
+
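+    /// Callback used when a request is scheduled for asynchronous execution;
+    /// Notify() is invoked once the inference completes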
+    template<typename CallbackContext>
+    class ArmnnThreadPoolCallback_1_2 : public armnn::IAsyncExecutionCallback
+    {
+    public:
+        ArmnnThreadPoolCallback_1_2(ArmnnPreparedModel_1_2<HalVersion>* model,
+                                    std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
+                                    std::vector<V1_2::OutputShape> outputShapes,
+                                    std::shared_ptr<armnn::InputTensors>& inputTensors,
+                                    std::shared_ptr<armnn::OutputTensors>& outputTensors,
+                                    CallbackContext callbackContext) :
+                m_Model(model),
+                m_MemPools(pMemPools),
+                m_OutputShapes(outputShapes),
+                m_InputTensors(inputTensors),
+                m_OutputTensors(outputTensors),
+                m_CallbackContext(callbackContext)
+        {}
+
+        void Notify(armnn::Status status, armnn::InferenceTimingPair timeTaken) override;
+
+        // Status is reported via Notify(); this override simply returns Success
+        virtual armnn::Status GetStatus() const override
+        {
+            return armnn::Status::Success;
+        }
+
+        // No-op: completion is signalled through Notify() rather than by waiting here
+        virtual void Wait() const override
+        {}
+
+        // Start time is not recorded here; the measured timing is passed to Notify()
+        virtual armnn::HighResolutionClock GetStartTime() const override
+        {
+            return std::chrono::high_resolution_clock::now();
+        }
+
+        // End time is not recorded here; the measured timing is passed to Notify()
+        virtual armnn::HighResolutionClock GetEndTime() const override
+        {
+            return std::chrono::high_resolution_clock::now();
+        }
+
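+        // Keep the request's memory pools and tensors alive until the
+        // asynchronous execution has completed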
+        ArmnnPreparedModel_1_2<HalVersion>* m_Model;
+        std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>> m_MemPools;
+        std::vector<V1_2::OutputShape> m_OutputShapes;
+        std::shared_ptr<armnn::InputTensors> m_InputTensors;
+        std::shared_ptr<armnn::OutputTensors> m_OutputTensors;
+        CallbackContext m_CallbackContext;
+    };
+
     Return<V1_0::ErrorStatus> Execute(const V1_0::Request& request,
                                       V1_2::MeasureTiming measureTiming,
                                       CallbackAsync_1_2 callback);
@@ -101,6 +157,14 @@
     template <typename TensorBindingCollection>
     void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings);
 
+    /// Schedule the graph prepared from the request for execution
+    template<typename CallbackContext>
+    void ScheduleGraphForExecution(
+            std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
+            std::shared_ptr<armnn::InputTensors>& inputTensors,
+            std::shared_ptr<armnn::OutputTensors>& outputTensors,
+            CallbackContext callbackContext);
+
     armnn::NetworkId                                                            m_NetworkId;
     armnn::IRuntime*                                                            m_Runtime;
     V1_2::Model                                                                 m_Model;
@@ -112,6 +176,10 @@
     uint32_t                                                                    m_RequestCount;
     const std::string&                                                          m_RequestInputsAndOutputsDumpDir;
     const bool                                                                  m_GpuProfilingEnabled;
+
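+    // Working memory handle used by the asynchronous execution path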
+    std::unique_ptr<armnn::IWorkingMemHandle>                                   m_WorkingMemHandle;
+    const bool                                                                  m_AsyncModelExecutionEnabled;
 };
 
 }