Reduce code duplication in the HAL 1.2 prepared model

Signed-off-by: Derek Lamberti <derek.lamberti@arm.com>
Change-Id: Ic2e8964745a4323efb1e06d466c0699f17a70c55
diff --git a/ArmnnPreparedModel_1_2.hpp b/ArmnnPreparedModel_1_2.hpp
index f609ef7..e68614a 100644
--- a/ArmnnPreparedModel_1_2.hpp
+++ b/ArmnnPreparedModel_1_2.hpp
@@ -19,18 +19,21 @@
 namespace armnn_driver
 {
 
-typedef std::function<void(::android::hardware::neuralnetworks::V1_0::ErrorStatus status,
-        std::vector<::android::hardware::neuralnetworks::V1_2::OutputShape> outputShapes,
-        const ::android::hardware::neuralnetworks::V1_2::Timing& timing,
-        std::string callingFunction)> armnnExecuteCallback_1_2;
+using CallbackAsync_1_2 = std::function< // Result callback: error status, output shapes, timing info, calling function name.
+                                void(V1_0::ErrorStatus errorStatus,
+                                     std::vector<::android::hardware::neuralnetworks::V1_2::OutputShape> outputShapes,
+                                     const ::android::hardware::neuralnetworks::V1_2::Timing& timing,
+                                     std::string callingFunction)>; // Replaces the armnnExecuteCallback_1_2 typedef.
 
-struct ArmnnCallback_1_2
+struct ExecutionContext_1_2 // Execution state paired with a CallbackAsync_1_2 via CallbackContext_1_2; the callback is no longer a member.
 {
-    armnnExecuteCallback_1_2 callback;
+    ::android::hardware::neuralnetworks::V1_2::MeasureTiming    measureTimings = // Presumably whether timing is measured; NOTE(review): plural name differs from the removed measureTiming member.
+        ::android::hardware::neuralnetworks::V1_2::MeasureTiming::NO; // Defaults to MeasureTiming::NO.
     TimePoint driverStart;
-    MeasureTiming measureTiming;
 };
 
+using CallbackContext_1_2 = CallbackContext<CallbackAsync_1_2, ExecutionContext_1_2>; // Bundles the callback with its execution context; the context type m_RequestThread is instantiated with.
+
 template <typename HalVersion>
 class ArmnnPreparedModel_1_2 : public V1_2::IPreparedModel
 {
@@ -62,19 +65,38 @@
             configureExecutionBurst_cb cb) override;
 
     /// execute the graph prepared from the request
-    void ExecuteGraph(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
-                      std::shared_ptr<armnn::InputTensors>& pInputTensors,
-                      std::shared_ptr<armnn::OutputTensors>& pOutputTensors,
-                      ArmnnCallback_1_2 callbackDescriptor);
+    template<typename CallbackContext>
+    bool ExecuteGraph(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
+                      armnn::InputTensors& inputTensors,
+                      armnn::OutputTensors& outputTensors,
+                      CallbackContext callback);
 
     /// Executes this model with dummy inputs (e.g. all zeroes).
     /// \return false on failure, otherwise true
     bool ExecuteWithDummyInputs();
 
 private:
-    Return <V1_0::ErrorStatus> Execute(const V1_0::Request& request,
-                                       MeasureTiming measureTiming,
-                                       armnnExecuteCallback_1_2 callback);
+    Return<V1_0::ErrorStatus> Execute(const V1_0::Request& request,
+                                      MeasureTiming measureTiming,
+                                      CallbackAsync_1_2 callback);
+
+    Return<V1_0::ErrorStatus> PrepareMemoryForInputs(
+            armnn::InputTensors& inputs,
+            const V1_0::Request& request,
+            const std::vector<android::nn::RunTimePoolInfo>& memPools);
+
+    Return<V1_0::ErrorStatus> PrepareMemoryForOutputs(
+            armnn::OutputTensors& outputs,
+            std::vector<OutputShape> &outputShapes,
+            const V1_0::Request& request,
+            const std::vector<android::nn::RunTimePoolInfo>& memPools);
+
+    Return <V1_0::ErrorStatus> PrepareMemoryForIO(
+            armnn::InputTensors& inputs,
+            armnn::OutputTensors& outputs,
+            std::vector<android::nn::RunTimePoolInfo>& memPools,
+            const V1_0::Request& request,
+            CallbackAsync_1_2 callback);
 
     template <typename TensorBindingCollection>
     void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings);
@@ -84,7 +106,9 @@
     V1_2::Model                                                                 m_Model;
     // There must be a single RequestThread for all ArmnnPreparedModel objects to ensure serial execution of workloads
     // It is specific to this class, so it is declared as static here
-    static RequestThread<ArmnnPreparedModel_1_2, HalVersion, ArmnnCallback_1_2> m_RequestThread;
+    static RequestThread<ArmnnPreparedModel_1_2,
+                         HalVersion,
+                         CallbackContext_1_2>                                   m_RequestThread;
     uint32_t                                                                    m_RequestCount;
     const std::string&                                                          m_RequestInputsAndOutputsDumpDir;
     const bool                                                                  m_GpuProfilingEnabled;