IVGCVSW-6312 Support pre-importing inputs
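
Pre-importing lets a caller import and map a set of InputTensors once and
then reuse them across many Execute() calls, instead of importing or copying
them on every execution. ImportInputs() returns one ImportedInputId per
tensor; those ids are passed to Execute() alongside any remaining
(non-imported) InputTensors. A minimal usage sketch, assuming an
async-enabled network loaded as networkId with inputs bound to
LayerBindingIds 0 and 1 and an output bound to 2:

    // Import and map input 0 once; it can be reused in later executions.
    std::vector<armnn::ImportedInputId> importedIds =
        runtime->ImportInputs(networkId, {{0, inputTensor0}});

    auto memHandle = runtime->CreateWorkingMemHandle(networkId);

    // Supply input 1 normally; input 0 is referenced by its ImportedInputId.
    runtime->Execute(*memHandle, {{1, inputTensor1}}, {{2, outputTensor}}, importedIds);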

Signed-off-by: Finn Williams <Finn.Williams@arm.com>
Change-Id: Ifc5e6f2e36767cb2a5cbf281d40ec9989b581abc
diff --git a/include/armnn/IRuntime.hpp b/include/armnn/IRuntime.hpp
index 345cdeb..908fe76 100644
--- a/include/armnn/IRuntime.hpp
+++ b/include/armnn/IRuntime.hpp
@@ -221,6 +221,14 @@
     TensorInfo GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const;
     TensorInfo GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const;
 
+    /// ImportInputs separates the importing and mapping of InputTensors from network execution.
+    /// It allows a set of InputTensors to be imported and mapped once, then used in execution many times.
+    /// This function is not thread safe and must not be used while other threads are calling Execute().
+    /// Only compatible with AsyncEnabled networks.
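+    ///
+    /// A short usage sketch (identifiers are illustrative):
+    ///     std::vector<ImportedInputId> ids = runtime->ImportInputs(networkId, {{inputBindingId, inputTensor}});
+    ///     runtime->Execute(workingMemHandle, otherInputTensors, outputTensors, ids);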
+    std::vector<ImportedInputId> ImportInputs(NetworkId networkId, const InputTensors& inputTensors);
+
     /// Evaluates a network using input in inputTensors and outputs filled into outputTensors
     Status EnqueueWorkload(NetworkId networkId,
                            const InputTensors& inputTensors,
@@ -232,7 +240,8 @@
     /// Will block until this and any other thread using the same workingMem object completes.
     Status Execute(IWorkingMemHandle& workingMemHandle,
                    const InputTensors& inputTensors,
-                   const OutputTensors& outputTensors);
+                   const OutputTensors& outputTensors,
+                   std::vector<ImportedInputId> preImportedInputs = {});
 
     /// Unloads a network from the IRuntime.
     /// At the moment this only removes the network from the m_Impl->m_Network.
diff --git a/include/armnn/Types.hpp b/include/armnn/Types.hpp
index 0b7b941..c3b439a 100644
--- a/include/armnn/Types.hpp
+++ b/include/armnn/Types.hpp
@@ -265,6 +265,7 @@
 
 /// Type of identifiers for bindable layers (inputs, outputs).
 using LayerBindingId = int;
+using ImportedInputId = unsigned int;
 
 class PermutationVector
 {
diff --git a/src/armnn/LoadedNetwork.cpp b/src/armnn/LoadedNetwork.cpp
index 9b5748b..3d7173b 100644
--- a/src/armnn/LoadedNetwork.cpp
+++ b/src/armnn/LoadedNetwork.cpp
@@ -146,7 +146,7 @@
             IBackendInternal* backend = it.first->second.get();
 
             if (networkProperties.m_AsyncEnabled &&
-                HasCapability(BackendOptions::BackendOption{"AsyncExecution", false}, backend->GetCapabilities()))
+                !HasCapability(BackendOptions::BackendOption{"AsyncExecution", true}, backend->GetCapabilities()))
             {
                 std::string er = backend->GetId();
                 er += " does not support AsyncExecution";
@@ -156,7 +156,8 @@
             if (backend->SupportsTensorAllocatorAPI())
             {
                 auto workloadFactory = backend->CreateWorkloadFactory(
-                    m_TensorHandleFactoryRegistry, m_OptimizedNetwork->pOptimizedNetworkImpl->GetModelOptions(),
+                    m_TensorHandleFactoryRegistry,
+                    m_OptimizedNetwork->pOptimizedNetworkImpl->GetModelOptions(),
                     static_cast<MemorySourceFlags>(m_NetworkProperties.m_InputSource),
                     static_cast<MemorySourceFlags>(m_NetworkProperties.m_OutputSource));
                 m_WorkloadFactories.emplace(
@@ -857,29 +858,20 @@
     return success;
 }
 
-void LoadedNetwork::EnqueueInput(const BindableLayer& layer,
-                                 const ConstTensor& inputTensor,
-                                 WorkingMemHandle& context)
+void LoadedNetwork::EnqueueInput(const ConstTensor& inputTensor,
+                                 ITensorHandle* inputTensorHandle)
 {
-    if (layer.GetType() != LayerType::Input)
-    {
-        throw InvalidArgumentException("EnqueueInput: given layer not an InputLayer");
-    }
-    LayerGuid id = layer.GetGuid();
-    WorkingMemDescriptor descriptor = context.GetWorkingMemDescriptor(id);
-
-    MemorySourceFlags importFlags = descriptor.m_Outputs[0]->GetImportFlags();
+    MemorySourceFlags importFlags = inputTensorHandle->GetImportFlags();
     if (m_NetworkProperties.m_ImportEnabled)  // Try import the input tensor
     {
         if (CheckFlag(importFlags, m_NetworkProperties.m_InputSource) )
         {
-            // This assumes a CPU Tensor handle
             std::unique_ptr<ITensorHandle> tensorHandle =
                     std::make_unique<ConstPassthroughTensorHandle>(inputTensor.GetInfo(),
-                                                                      inputTensor.GetMemoryArea());
-
+                                                                   inputTensor.GetMemoryArea());
             void* mem = tensorHandle->Map(false);
-            if (descriptor.m_Outputs[0]->Import(mem, m_NetworkProperties.m_InputSource))
+
+            if (inputTensorHandle->Import(mem, m_NetworkProperties.m_InputSource))
             {
                 tensorHandle->Unmap();
                 return;
@@ -902,10 +894,7 @@
             memcpy(dst, src, size);
         };
 
-        for (const auto& input : descriptor.m_Outputs)
-        {
-            CopyTensorContentsGeneric(tensorHandle.get(), input, copyFunc);
-        }
+        CopyTensorContentsGeneric(tensorHandle.get(), inputTensorHandle, copyFunc);
     }
 }
 
@@ -1009,9 +998,78 @@
     throw InvalidArgumentException("Output does not exist.");
 }
 
+std::vector<ImportedInputId> LoadedNetwork::ImportInputs(const InputTensors& inputTensors)
+{
+    if (!m_NetworkProperties.m_ImportEnabled)  // Import is not enabled for this network
+    {
+        throw MemoryImportException("ImportInputs: Memory Import failed, NetworkProperties.m_ImportEnabled is false");
+    }
+
+    std::vector<ImportedInputId> importedInputs;
+    Graph& graph = m_OptimizedNetwork->pOptimizedNetworkImpl->GetGraph().TopologicalSort();
+
+    for (auto inputTensor : inputTensors)
+    {
+        auto layerBindingId = inputTensor.first;
+        auto it = std::find_if(graph.GetInputLayers().begin(), graph.GetInputLayers().end(), [=](auto* layer)
+        {
+            return layer->GetBindingId() == layerBindingId;
+        });
+
+        if (it == graph.GetInputLayers().end())
+        {
+            throw MemoryImportException("ImportInputs: Memory Import failed, backend does not support Import");
+        }
+
+        const Layer* layer = *it;
+        if (layer->GetType() != LayerType::Input)
+        {
+            throw InvalidArgumentException("ImportInputs: given layer not an InputLayer");
+        }
+
+        const OutputSlot& outputSlot = layer->GetOutputSlots()[0];
+
+        ITensorHandleFactory::FactoryId factoryId = outputSlot.GetTensorHandleFactoryId();
+        const TensorInfo& tensorInfo = outputSlot.GetTensorInfo();
+
+        ITensorHandleFactory* handleFactory = m_TensorHandleFactoryRegistry.GetFactory(factoryId);
+        ARMNN_ASSERT(handleFactory);
+
+        m_PreImportedInputHandles.emplace_back(layerBindingId,
+                                               handleFactory->CreateTensorHandle(tensorInfo, false));
+
+        ITensorHandle* tensorHandle = m_PreImportedInputHandles.back().m_TensorHandle.get();
+
+        if (!CheckFlag(tensorHandle->GetImportFlags(), m_NetworkProperties.m_InputSource))
+        {
+            throw MemoryImportException(
+                fmt::format("ImportInputs: Memory Import failed, backend: {} does not support importing from source {}",
+                            factoryId, m_NetworkProperties.m_InputSource));
+        }
+
+        std::unique_ptr<ITensorHandle> passThroughTensorHandle =
+                std::make_unique<ConstPassthroughTensorHandle>(inputTensor.second.GetInfo(),
+                                                               inputTensor.second.GetMemoryArea());
+
+        if (tensorHandle->Import(passThroughTensorHandle->Map(), m_NetworkProperties.m_InputSource))
+        {
+            importedInputs.push_back(m_CurImportedInputId++);
+            passThroughTensorHandle->Unmap();
+        }
+        else
+        {
+            passThroughTensorHandle->Unmap();
+            throw MemoryImportException("ImportInputs: Memory Import failed");
+        }
+    }
+
+    return importedInputs;
+}
+
 Status LoadedNetwork::Execute(const InputTensors& inputTensors,
                               const OutputTensors& outputTensors,
-                              IWorkingMemHandle& iWorkingMemHandle)
+                              IWorkingMemHandle& iWorkingMemHandle,
+                              std::vector<ImportedInputId> preImportedInputs)
 {
     const Graph& graph = m_OptimizedNetwork->pOptimizedNetworkImpl->GetGraph();
 
@@ -1022,9 +1080,72 @@
         return Status::Failure;
     }
 
-    if (graph.GetNumInputs() != inputTensors.size())
+    if (inputTensors.size() + preImportedInputs.size() != graph.GetNumInputs())
     {
-        throw InvalidArgumentException("Number of inputs provided does not match network.");
+        if (preImportedInputs.empty())
+        {
+            throw InvalidArgumentException("Number of inputs provided does not match network.");
+        }
+        else
+        {
+            throw InvalidArgumentException("Number of inputs + preImportedInputs provided does not match network.");
+        }
+    }
+
+    WorkingMemHandle& workingMemHandle = dynamic_cast<WorkingMemHandle&>(iWorkingMemHandle);
+
+    // This map is a quick way to check for duplicate or non-existing LayerBindingIds
+    std::unordered_map<LayerBindingId, bool> validationMap = workingMemHandle.GetValidationMap();
+    for (auto pair : inputTensors)
+    {
+        const LayerBindingId layerBindingId = pair.first;
+
+        try
+        {
+            bool& previouslyUsed = validationMap.at(pair.first);
+            if (previouslyUsed)
+            {
+                throw InvalidArgumentException(fmt::format("Duplicate LayerbindingId: {} ", layerBindingId));
+            }
+            else
+            {
+                previouslyUsed = true;
+            }
+        }
+        catch (const std::out_of_range&)
+        {
+            throw InvalidArgumentException(fmt::format("Unknown LayerBindingId: {}", layerBindingId));
+        }
+    }
+
+    if (!preImportedInputs.empty())
+    {
+        const unsigned int maxPreImportedId = *std::max_element(preImportedInputs.begin(), preImportedInputs.end());
+        // Valid ImportedInputIds are in the range [0, m_CurImportedInputId)
+        if (maxPreImportedId >= m_CurImportedInputId)
+        {
+            throw InvalidArgumentException(fmt::format("Invalid ImportedInputId: {}", maxPreImportedId));
+        }
+        for (ImportedInputId id : preImportedInputs)
+        {
+            const LayerBindingId layerBindingId = m_PreImportedInputHandles[id].m_LayerBindingId;
+
+            try
+            {
+                bool& previouslyUsed = validationMap.at(layerBindingId);
+                if (previouslyUsed)
+                {
+                    throw InvalidArgumentException(fmt::format("Duplicate LayerbindingId: {} ", layerBindingId));
+                }
+                else
+                {
+                    previouslyUsed = true;
+                }
+            }
+            catch (const std::out_of_range&)
+            {
+                throw InvalidArgumentException(fmt::format("Unknown LayerBindingId id: {}", layerBindingId));
+            }
+        }
     }
 
     std::unique_ptr<profiling::TimelineUtilityMethods> timelineUtils =
@@ -1050,8 +1171,6 @@
         timelineUtils->RecordEvent(inferenceGuid, profiling::LabelsAndEventClasses::ARMNN_PROFILING_EOL_EVENT_CLASS);
         timelineUtils->Commit();
     }
-    WorkingMemHandle& workingMemHandle = dynamic_cast<WorkingMemHandle&>(iWorkingMemHandle);
-    std::lock_guard<std::mutex> lockGuard(workingMemHandle.GetMutex());
 
     if (!workingMemHandle.IsAllocated())
     {
@@ -1060,9 +1179,23 @@
 
     {
         ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "PrepareInputs");
-        for (const BindableLayer* inputLayer : graph.GetInputLayers())
+        // Swap in the pre-imported inputs if any
+        for (ImportedInputId id : preImportedInputs)
         {
-            EnqueueInput(*inputLayer, GetInputTensor(inputLayer->GetBindingId(), inputTensors), workingMemHandle);
+            const ImportedInputHandlePin& importedInputPin = m_PreImportedInputHandles[id];
+
+            const LayerBindingId layerBindingId = m_PreImportedInputHandles[id].m_LayerBindingId;
+            ITensorHandle* preimportedHandle = importedInputPin.m_TensorHandle.get();
+            auto inputConnections = workingMemHandle.GetInputConnections(layerBindingId);
+            for (auto it : inputConnections)
+            {
+                *it = preimportedHandle;
+            }
+        }
+
+        for (auto pair : inputTensors)
+        {
+            EnqueueInput(pair.second, workingMemHandle.GetInputHandle(pair.first));
         }
     }
 
@@ -1108,6 +1241,19 @@
         }
     }
 
+    // Restore the workingMemHandle to its original state
+    for (ImportedInputId id : preImportedInputs)
+    {
+        const LayerBindingId layerBindingId = m_PreImportedInputHandles[id].m_LayerBindingId;
+
+        auto inputHandle = workingMemHandle.GetInputHandle(layerBindingId);
+        auto inputConnections = workingMemHandle.GetInputConnections(layerBindingId);
+        for (auto it : inputConnections)
+        {
+            *it = inputHandle;
+        }
+    }
+
     return executionSucceeded ? Status::Success : Status::Failure;
 }
 
@@ -1166,7 +1312,20 @@
         }
     };
 
-    std::unordered_map<const ITensorHandle*, unsigned int> handleReferenceCounts;
+    struct HandleInfo
+    {
+        unsigned int m_ReferenceCount = 0;
+        bool m_IsInputLayer = false;
+        LayerBindingId m_LayerBindingId = -1;
+    };
+
+    std::vector<WorkingMemHandle::InputConnectionInfo> inputConnections;
+    std::vector<std::pair<LayerBindingId, LayerGuid>> inputIndexes;
+
+    std::unordered_map<const ITensorHandle*, HandleInfo> handleReferenceCounts;
+
+    unsigned int workingMemDescriptorIndex = 0;
     for (auto&& layer : order)
     {
         WorkingMemDescriptor workingMemDescriptor;
@@ -1177,7 +1336,7 @@
             continue;
         }
         bool isMemoryManaged = true;
-        bool isInputLayer = true;
+        bool isInputLayer = false;
         // Look for the layer with 1 OutputSlot which has 1 connection and that connection is an Output Layer
         // If Export is enabled disable memory management so we can export, otherwise we do a copy
         if ((layer->GetNumOutputSlots() == 1) &&
@@ -1190,8 +1349,8 @@
         {
             // Input layers/workloads will not be executed so the descriptor is not added to workingMemDescriptors
             // However we will still need to manage the tensorHandle
-            isInputLayer = false;
-            isMemoryManaged = !m_NetworkProperties.m_ExportEnabled;
+            isInputLayer = true;
+            isMemoryManaged = !m_NetworkProperties.m_ImportEnabled;
         }
 
         // Create a tensor handle for each output slot of a layer
@@ -1206,7 +1365,17 @@
             unsigned int numConnections = slot.GetNumConnections();
             ARMNN_ASSERT(numConnections != 0);
 
-            handleReferenceCounts[tensorHandle] = numConnections;
+            handleReferenceCounts[tensorHandle].m_ReferenceCount = numConnections;
+
+            if (isInputLayer)
+            {
+                HandleInfo& handleInfo = handleReferenceCounts[tensorHandle];
+                handleInfo.m_IsInputLayer = true;
+
+                LayerBindingId bindingId = static_cast<BindableLayer*>(layer)->GetBindingId();
+                handleInfo.m_LayerBindingId = bindingId;
+
+                inputIndexes.emplace_back(bindingId, layer->GetGuid());
+            }
         }
         // Loop through the input slots in the same layer and decrement the reference counter associated
         // to each tensor handle we encounter.
@@ -1230,8 +1399,18 @@
             unsigned int index = outputSlot->CalculateIndexOnOwner();
             ITensorHandle* inputTensorHandle = search->second[index].get();
             workingMemDescriptor.m_Inputs.push_back(inputTensorHandle);
-            --handleReferenceCounts.at(inputTensorHandle);
-            if (handleReferenceCounts.at(inputTensorHandle) == 0u)
+
+            HandleInfo& handleInfo = handleReferenceCounts.at(inputTensorHandle);
+
+            // Store the descriptor index and slot index of every connection to an input layer's
+            // tensor handle, so the handle can later be swapped for a pre-imported one
+            if (handleInfo.m_IsInputLayer)
+            {
+                inputConnections.emplace_back(WorkingMemHandle::InputConnectionInfo{
+                        handleInfo.m_LayerBindingId, workingMemDescriptorIndex, slot.GetSlotIndex()});
+            }
+
+            --handleInfo.m_ReferenceCount;
+            if (handleInfo.m_ReferenceCount == 0u)
             {
                 // Stop managing lifetime of tensor handle
                 inputTensorHandle->Allocate();
@@ -1242,13 +1421,16 @@
 
         // Input layers/workloads will not be executed, so the descriptor is not added to workingMemDescriptors
         // However we will still need to manage the tensorHandle
-        if (isInputLayer)
+        if (!isInputLayer)
         {
             workingMemDescriptors.push_back(workingMemDescriptor);
+            workingMemDescriptorIndex++;
         }
     }
 
     return std::make_unique<WorkingMemHandle>(networkId,
+                                              inputIndexes,
+                                              inputConnections,
                                               workingMemDescriptors,
                                               workingMemDescriptorMap,
                                               memoryManagers,
diff --git a/src/armnn/LoadedNetwork.hpp b/src/armnn/LoadedNetwork.hpp
index 360ad91..e713be2 100644
--- a/src/armnn/LoadedNetwork.hpp
+++ b/src/armnn/LoadedNetwork.hpp
@@ -49,13 +49,16 @@
     TensorInfo GetInputTensorInfo(LayerBindingId layerId) const;
     TensorInfo GetOutputTensorInfo(LayerBindingId layerId) const;
 
+    std::vector<ImportedInputId> ImportInputs(const InputTensors& inputTensors);
+
     /// Single thread execution of the loaded network
     Status EnqueueWorkload(const InputTensors& inputTensors, const OutputTensors& outputTensors);
 
     /// Thread safe execution of the loaded network
     Status Execute(const InputTensors& inputTensors,
                    const OutputTensors& outputTensors,
-                   IWorkingMemHandle& workingMemHandle);
+                   IWorkingMemHandle& workingMemHandle,
+                   std::vector<ImportedInputId> preImportedInputs = {});
 
     static std::unique_ptr<LoadedNetwork> MakeLoadedNetwork(std::unique_ptr<IOptimizedNetwork> net,
                                                             std::string& errorMessage,
@@ -100,7 +103,7 @@
 
     void EnqueueOutput(const BindableLayer& layer, ITensorHandle* tensorHandle, const TensorInfo& tensorInfo);
 
-    void EnqueueInput(const BindableLayer& layer, const ConstTensor& inputTensor, WorkingMemHandle& handle);
+    void EnqueueInput(const ConstTensor& inputTensor, ITensorHandle* inputTensorHandle);
 
     void EnqueueOutput(const BindableLayer& layer, const Tensor& outputTensor, WorkingMemHandle& handle);
 
@@ -130,6 +133,22 @@
     TensorHandleFactoryRegistry m_TensorHandleFactoryRegistry;
 
     profiling::ProfilingService& m_ProfilingService;
+
+    struct ImportedInputHandlePin
+    {
+        ImportedInputHandlePin(LayerBindingId layerBindingId,
+                               std::unique_ptr<ITensorHandle> tensorHandle)
+        : m_LayerBindingId(layerBindingId)
+        , m_TensorHandle(std::move(tensorHandle))
+        {}
+
+        LayerBindingId m_LayerBindingId;
+        std::unique_ptr<ITensorHandle> m_TensorHandle;
+    };
+
+    std::vector<ImportedInputHandlePin> m_PreImportedInputHandles;
+
+    ImportedInputId m_CurImportedInputId = 0;
 };
 
 }
diff --git a/src/armnn/Runtime.cpp b/src/armnn/Runtime.cpp
index bbcbb9f..085cf2c 100644
--- a/src/armnn/Runtime.cpp
+++ b/src/armnn/Runtime.cpp
@@ -76,6 +76,12 @@
     return pRuntimeImpl->GetOutputTensorInfo(networkId, layerId);
 }
 
+std::vector<ImportedInputId> IRuntime::ImportInputs(NetworkId networkId, const InputTensors& inputTensors)
+{
+    return pRuntimeImpl->ImportInputs(networkId, inputTensors);
+}
+
 Status IRuntime::EnqueueWorkload(NetworkId networkId,
                                  const InputTensors& inputTensors,
                                  const OutputTensors& outputTensors)
@@ -85,9 +91,10 @@
 
 Status IRuntime::Execute(IWorkingMemHandle& workingMemHandle,
                          const InputTensors& inputTensors,
-                         const OutputTensors& outputTensors)
+                         const OutputTensors& outputTensors,
+                         std::vector<ImportedInputId> preImportedInputs)
 {
-    return pRuntimeImpl->Execute(workingMemHandle, inputTensors, outputTensors);
+    return pRuntimeImpl->Execute(workingMemHandle, inputTensors, outputTensors, preImportedInputs);
 }
 
 Status IRuntime::UnloadNetwork(NetworkId networkId)
@@ -476,6 +483,12 @@
     return GetLoadedNetworkPtr(networkId)->GetOutputTensorInfo(layerId);
 }
 
+std::vector<ImportedInputId> RuntimeImpl::ImportInputs(NetworkId networkId, const InputTensors& inputTensors)
+{
+    return GetLoadedNetworkPtr(networkId)->ImportInputs(inputTensors);
+}
 
 Status RuntimeImpl::EnqueueWorkload(NetworkId networkId,
                                 const InputTensors& inputTensors,
@@ -512,7 +525,8 @@
 
 Status RuntimeImpl::Execute(IWorkingMemHandle& iWorkingMemHandle,
                             const InputTensors& inputTensors,
-                            const OutputTensors& outputTensors)
+                            const OutputTensors& outputTensors,
+                            std::vector<ImportedInputId> preImportedInputs)
 {
     NetworkId networkId = iWorkingMemHandle.GetNetworkId();
     LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId);
@@ -531,7 +545,7 @@
 
     ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "Execute");
 
-    return loadedNetwork->Execute(inputTensors, outputTensors, iWorkingMemHandle);
+    return loadedNetwork->Execute(inputTensors, outputTensors, iWorkingMemHandle, preImportedInputs);
 }
 
 /// Create a new unique WorkingMemHandle object. Create multiple handles if you wish to have
diff --git a/src/armnn/Runtime.hpp b/src/armnn/Runtime.hpp
index 7a80acd..e947dce 100644
--- a/src/armnn/Runtime.hpp
+++ b/src/armnn/Runtime.hpp
@@ -55,6 +55,8 @@
     TensorInfo GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const;
     TensorInfo GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const;
 
+    std::vector<ImportedInputId> ImportInputs(NetworkId networkId, const InputTensors& inputTensors);
+
     // Evaluates network using input in inputTensors, outputs filled into outputTensors.
     Status EnqueueWorkload(NetworkId networkId,
                            const InputTensors& inputTensors,
@@ -66,7 +68,8 @@
     /// Will block until this and any other thread using the same workingMem object completes.
     Status Execute(IWorkingMemHandle& workingMemHandle,
                    const InputTensors& inputTensors,
-                   const OutputTensors& outputTensors);
+                   const OutputTensors& outputTensors,
+                   std::vector<ImportedInputId> preImportedInputs);
 
     /// Unloads a network from the Runtime.
     /// At the moment this only removes the network from the m_Impl->m_Network.
diff --git a/src/armnn/WorkingMemHandle.cpp b/src/armnn/WorkingMemHandle.cpp
index 1dcaa38..e402684 100644
--- a/src/armnn/WorkingMemHandle.cpp
+++ b/src/armnn/WorkingMemHandle.cpp
@@ -14,20 +14,52 @@
 namespace experimental
 {
 
-WorkingMemHandle::WorkingMemHandle(
-        NetworkId networkId,
+WorkingMemHandle::WorkingMemHandle(NetworkId networkId,
+        std::vector<std::pair<LayerBindingId, LayerGuid>> inputHandles,
+        std::vector<InputConnectionInfo> inputConnections,
         std::vector<WorkingMemDescriptor> workingMemDescriptors,
         std::unordered_map<LayerGuid, WorkingMemDescriptor> workingMemDescriptorMap,
         std::vector<std::shared_ptr<IMemoryManager>> memoryManagers,
-        std::unordered_map<LayerGuid, std::vector<std::unique_ptr<ITensorHandle> > > ownedTensorHandles) :
-    m_NetworkId(networkId),
-    m_WorkingMemDescriptors(workingMemDescriptors),
-    m_WorkingMemDescriptorMap(workingMemDescriptorMap),
-    m_MemoryManagers(memoryManagers),
-    m_OwnedTensorHandles(std::move(ownedTensorHandles)),
-    m_IsAllocated(false),
-    m_Mutex()
+        std::unordered_map<LayerGuid, std::vector<std::unique_ptr<ITensorHandle> > > ownedTensorHandles)
+    : m_NetworkId(networkId)
+    , m_WorkingMemDescriptors(workingMemDescriptors)
+    , m_WorkingMemDescriptorMap(workingMemDescriptorMap)
+    , m_MemoryManagers(memoryManagers)
+    , m_OwnedTensorHandles(std::move(ownedTensorHandles))
+    , m_IsAllocated(false)
+    , m_Mutex()
 {
+    // Create a map of LayerBindingIds to the corresponding input ITensorHandle*
+    for (auto pair : inputHandles)
+    {
+        m_InputHandleMap[pair.first] = m_WorkingMemDescriptorMap.at(pair.second).m_Outputs[0];
+        m_ValidationMap[pair.first] = false;
+    }
+
+    // For every input we need to store all locations from which that input's ITensorHandle* is read.
+    // So we can, at a later point, swap in and out the ITensorHandle* at that location.
+    for (auto inputConnectionInfo : inputConnections)
+    {
+        WorkingMemDescriptor& workingMemDescriptor = m_WorkingMemDescriptors[inputConnectionInfo.m_DescriptorIndex];
+
+        auto pos = workingMemDescriptor.m_Inputs.begin();
+        // m_InputIndex is unsigned, but a vector's difference_type is a signed integer type,
+        // so cast explicitly to avoid conversion warnings
+        pos += numeric_cast<std::vector<ITensorHandle*>::difference_type>(inputConnectionInfo.m_InputIndex);
+
+        m_InputConnectionMap[inputConnectionInfo.m_LayerBindingId].push_back(pos);
+    }
 }
 
 void WorkingMemHandle::Allocate()
diff --git a/src/armnn/WorkingMemHandle.hpp b/src/armnn/WorkingMemHandle.hpp
index 5e3fd66..676d042 100644
--- a/src/armnn/WorkingMemHandle.hpp
+++ b/src/armnn/WorkingMemHandle.hpp
@@ -21,11 +21,23 @@
 namespace experimental
 {
 
 class WorkingMemHandle final : public IWorkingMemHandle
 {
 
 public:
+    struct InputConnectionInfo
+    {
+        LayerBindingId m_LayerBindingId;
+        unsigned int m_DescriptorIndex;
+        unsigned int m_InputIndex;
+    };
+
+    WorkingMemHandle(NetworkId networkId) : m_NetworkId(networkId) {}
+
     WorkingMemHandle(NetworkId networkId,
+                     std::vector<std::pair<LayerBindingId, LayerGuid>> inputHandles,
+                     std::vector<InputConnectionInfo> inputConnections,
                      std::vector<WorkingMemDescriptor> workingMemDescriptors,
                      std::unordered_map<LayerGuid, WorkingMemDescriptor> workingMemDescriptorMap,
                      std::vector<std::shared_ptr<IMemoryManager>> memoryManagers,
@@ -39,8 +51,6 @@
         return m_NetworkId;
     }
 
-
-
     /// Allocate the backing memory required for execution. If this is not called, then allocation will be
     /// deferred to execution time. The mutex must be locked.
     void Allocate() override;
@@ -75,10 +85,28 @@
         return m_WorkingMemDescriptors[id];
     }
 
+    ITensorHandle* GetInputHandle(LayerBindingId layerBindingId) const
+    {
+        return m_InputHandleMap.at(layerBindingId);
+    }
+
+    const std::vector<std::vector<ITensorHandle*>::iterator>& GetInputConnections(LayerBindingId layerBindingId) const
+    {
+        return m_InputConnectionMap.at(layerBindingId);
+    }
+
+    std::unordered_map<LayerBindingId, bool> GetValidationMap() const
+    {
+        return m_ValidationMap;
+    }
+
 private:
     NetworkId m_NetworkId;
     std::shared_ptr<ProfilerImpl> m_Profiler;
 
+    std::unordered_map<LayerBindingId, ITensorHandle*> m_InputHandleMap;
+    std::unordered_map<LayerBindingId, std::vector<std::vector<ITensorHandle*>::iterator>> m_InputConnectionMap;
+
     std::vector<WorkingMemDescriptor> m_WorkingMemDescriptors;
     std::unordered_map<LayerGuid, WorkingMemDescriptor> m_WorkingMemDescriptorMap;
 
@@ -88,6 +116,7 @@
     // constant tensor's can be shared by multiple WorkingMemHandles and so will not be stored here
     std::unordered_map<LayerGuid, std::vector<std::unique_ptr<ITensorHandle> > >  m_OwnedTensorHandles;
 
+    std::unordered_map<LayerBindingId, bool> m_ValidationMap;
     bool m_IsAllocated;
     std::mutex m_Mutex;
 };
diff --git a/src/armnn/test/RuntimeTests.cpp b/src/armnn/test/RuntimeTests.cpp
index 4652d1c..abf13f5 100644
--- a/src/armnn/test/RuntimeTests.cpp
+++ b/src/armnn/test/RuntimeTests.cpp
@@ -61,6 +61,113 @@
     CHECK(runtime->UnloadNetwork(networkIdentifier1) == armnn::Status::Failure);
 }
 
+TEST_CASE("RuntimePreImportInputs")
+{
+    armnn::IRuntime::CreationOptions options;
+    armnn::IRuntimePtr               runtime(armnn::IRuntime::Create(options));
+
+    armnn::NetworkId   networkIdentifier1 = 1;
+
+    armnn::INetworkPtr testNetwork(armnn::INetwork::Create());
+    auto inputLayer1 = testNetwork->AddInputLayer(0, "input 1 layer");
+    auto inputLayer2 = testNetwork->AddInputLayer(1, "input 2 layer");
+    auto addLayer = testNetwork->AddAdditionLayer("add layer");
+    auto outputLayer = testNetwork->AddOutputLayer(2, "output layer");
+
+    TensorInfo tensorInfo{{4}, armnn::DataType::Signed32};
+
+    inputLayer1->GetOutputSlot(0).Connect(addLayer->GetInputSlot(0));
+    inputLayer1->GetOutputSlot(0).SetTensorInfo(tensorInfo);
+    inputLayer2->GetOutputSlot(0).Connect(addLayer->GetInputSlot(1));
+    inputLayer2->GetOutputSlot(0).SetTensorInfo(tensorInfo);
+
+    addLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
+    addLayer->GetOutputSlot(0).SetTensorInfo(tensorInfo);
+
+    std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef };
+
+    std::string er;
+    armnn::INetworkProperties networkProperties(true, MemorySource::Malloc, MemorySource::Undefined);
+    runtime->LoadNetwork(networkIdentifier1,
+                         Optimize(*testNetwork, backends, runtime->GetDeviceSpec()),
+                         er,
+                         networkProperties);
+
+    std::vector<int> inputData1(4, 10);
+    std::vector<int> inputData2(4, 20);
+    std::vector<int> output(4);
+
+    ConstTensor inputTensor1({{4}, armnn::DataType::Signed32}, inputData1.data());
+    ConstTensor inputTensor2({{4}, armnn::DataType::Signed32}, inputData2.data());
+
+    Tensor outputTensor({{4}, armnn::DataType::Signed32}, output.data());
+
+    auto importedInputVec1 = runtime->ImportInputs(networkIdentifier1, {{0, inputTensor1}});
+    CHECK(importedInputVec1.size() == 1);
+    CHECK(importedInputVec1[0] == 0);
+
+    auto memHandle = runtime->CreateWorkingMemHandle(networkIdentifier1);
+
+    runtime->Execute(*memHandle.get(), {{1, inputTensor2}}, {{2, outputTensor}}, {0 /* pre-imported id */});
+    for (auto val : output)
+    {
+        CHECK(val == 30);
+    }
+
+    auto importedInputVec2 = runtime->ImportInputs(networkIdentifier1, {{1, inputTensor2}});
+    CHECK(importedInputVec2.size() == 1);
+    CHECK(importedInputVec2[0] == 1);
+
+    runtime->Execute(*memHandle.get(), {{0, inputTensor1}}, {{2, outputTensor}}, {1 /* pre-imported id */});
+    for (auto val : output)
+    {
+        CHECK(val == 30);
+    }
+
+    runtime->Execute(*memHandle.get(), {}, {{2, outputTensor}}, {0, 1});
+    for (auto val : output)
+    {
+        CHECK(val == 30);
+    }
+
+    // Duplicate ImportedInputId and LayerBindingId
+    CHECK_THROWS_AS(runtime->Execute(*memHandle.get(),{},{{2, outputTensor}},{0, 0});
+                    , armnn::InvalidArgumentException);
+
+    // Duplicate LayerBindingId
+    CHECK_THROWS_AS(runtime->Execute(*memHandle.get(), {{1, inputTensor2}}, {{2, outputTensor}},{1});
+                    , armnn::InvalidArgumentException);
+
+    // Incorrect ImportedInputId
+    CHECK_THROWS_AS(runtime->Execute(*memHandle.get(), {{1, inputTensor2}}, {{2, outputTensor}},{10});
+                    , armnn::InvalidArgumentException);
+
+    // Incorrect LayerBindingId
+    CHECK_THROWS_AS(runtime->Execute(*memHandle.get(), {{-2, inputTensor2}}, {{2, outputTensor}},{1});
+                    , armnn::InvalidArgumentException);
+
+    // Incorrect LayerBindingId and ImportedInputId
+    CHECK_THROWS_AS(runtime->Execute(*memHandle.get(), {{-2, inputTensor2}}, {{2, outputTensor}}, {10}),
+                    armnn::InvalidArgumentException);
+
+    auto importedInputVec3 = runtime->ImportInputs(networkIdentifier1, {{1, inputTensor2}});
+    CHECK(importedInputVec3[0] == 2);
+    // Too many ImportedInputIds
+    CHECK_THROWS_AS(runtime->Execute(*memHandle.get(), {}, {{2, outputTensor}},{0, 1, 2});
+                    , armnn::InvalidArgumentException);
+
+    // Too many InputTensors
+    CHECK_THROWS_AS(runtime->Execute(*memHandle.get(),
+                                     {{0, inputTensor2}, {1, inputTensor2}, {2, inputTensor2}},
+                                     {{2, outputTensor}});
+                    , armnn::InvalidArgumentException);
+
+    // Too few ImportedInputIds
+    CHECK_THROWS_AS(runtime->Execute(*memHandle.get(), {}, {{2, outputTensor}},{0});
+                    , armnn::InvalidArgumentException);
+}
+
 // Note: the current builds we don't do valgrind and gperftools based leak checking at the same
 //       time, so in practice WITH_VALGRIND and ARMNN_LEAK_CHECKING_ENABLED are exclusive. The
 //       valgrind tests can stay for x86 builds, but on hikey Valgrind is just way too slow