IVGCVSW-5818 Enable import on GPU

Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: I4e4eb107aa2bfa09625840d738001f33152e6792
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp
index b79576c..f097e67 100644
--- a/src/armnn/Network.cpp
+++ b/src/armnn/Network.cpp
@@ -1165,7 +1165,8 @@
 // Find the handle factory for the input layer which results in fewest required copies.
 ITensorHandleFactory::FactoryId CalculateSlotOptionForInput(BackendsMap& backends,
                                                             OutputSlot& slot,
-                                                            TensorHandleFactoryRegistry& registry)
+                                                            TensorHandleFactoryRegistry& registry,
+                                                            bool importEnabled)
 {
     Layer& layer = slot.GetOwningLayer();
     ARMNN_ASSERT(layer.GetType() == LayerType::Input);
@@ -1191,6 +1192,7 @@
 
     for (auto&& connection : slot.GetConnections())
     {
+
         const Layer& connectedLayer = connection->GetOwningLayer();
 
         auto toBackend = backends.find(connectedLayer.GetBackendId());
@@ -1208,11 +1210,12 @@
             // Input layers use the mem copy workload or import, so the selected factory must
             // support either the map/unmap API or Import API
             ITensorHandleFactory* factory = registry.GetFactory(dst);
-            if (!factory->SupportsMapUnmap() &&
-                !CheckFlag(factory->GetImportFlags(), MemorySource::Malloc)) // Just support cpu mem imports for now
+            // With import enabled the factory must support import; otherwise it must support map/unmap.
+            const bool supported = importEnabled ? factory->GetImportFlags() != 0 : factory->SupportsMapUnmap();
+            if (!supported)
             {
-                // The current tensor handle factory does not support the map/unmap or import
-                // strategy, move to the next one
+                // The current tensor handle factory does not support the required import or map/unmap
+                // strategy, move to the next one
                 continue;
             }
 
@@ -1257,7 +1260,8 @@
 // when considering all connections.
 ITensorHandleFactory::FactoryId CalculateSlotOption(BackendsMap& backends,
                                                     OutputSlot& outputSlot,
-                                                    TensorHandleFactoryRegistry& registry)
+                                                    TensorHandleFactoryRegistry& registry,
+                                                    bool importEnabled)
 {
     // First ensure the from backends can support the TensorHandeAPI
     Layer& layer = outputSlot.GetOwningLayer();
@@ -1268,14 +1272,13 @@
         return ITensorHandleFactory::LegacyFactoryId;
     }
 
-    // Connections to Output Layers requires support for map/unmap on the TensorHandle.
-    bool requiresMapUnmap = false;
+    bool outputConnection = false;
     for (auto&& connection : outputSlot.GetConnections())
     {
         const Layer& connectedLayer = connection->GetOwningLayer();
         if (connectedLayer.GetType() == LayerType::Output)
         {
-            requiresMapUnmap = true;
+            outputConnection = true;
         }
     }
 
@@ -1286,9 +1289,49 @@
     std::map<ITensorHandleFactory::FactoryId, int> factoryScores;
+    // A fallback layer has at least one input produced on a different backend (unconnected slots are ignored).
+    // This does not depend on the candidate factory, so it is computed once rather than per preference.
+    bool fallbackConnection = false;
+    for (auto&& inputSlot : layer.GetInputSlots())
+    {
+        const auto* connectedSlot = inputSlot.GetConnectedOutputSlot();
+        if (connectedSlot != nullptr && connectedSlot->GetOwningLayer().GetBackendId() != layer.GetBackendId())
+        {
+            fallbackConnection = true;
+            break;
+        }
+    }
+    // Only a non-fallback layer feeding an Output layer requires export support; in every other case the
+    // factory must not report that fallback import is disabled.
+    const bool requiresExport = outputConnection && !fallbackConnection;
+
     for (auto&& pref : srcPrefs)
     {
-        if (requiresMapUnmap) // Only consider factories that support map/unmap if required
+        if (importEnabled)
         {
             ITensorHandleFactory* factory = registry.GetFactory(pref);
+            if (requiresExport)
+            {
+                // This layer feeds an Output layer and is not a fallback layer, so the factory must be
+                // able to export its memory.
+                if (factory->GetExportFlags() == 0)
+                {
+                    continue;
+                }
+            }
+            else
+            {
+                // Either there is no Output connection or this is a fallback layer: cannot use factory
+                // import if fallback import is not supported.
+                auto factoryCap = factory->GetCapabilities(&layer, &layer, CapabilityClass::FallbackImportDisabled);
+                if (!factoryCap.empty())
+                {
+                    continue;
+                }
+            }
+        }
+        else
+        {
+            // Only consider factories that support map/unmap
+            ITensorHandleFactory* factory = registry.GetFactory(pref);
             if (!factory->SupportsMapUnmap())
             {
                 // The current tensor handle factory does not support the map/unmap strategy, move to the next one
@@ -1296,6 +1339,7 @@
             }
         }
 
+
         auto it = factoryScores.find(pref);
         if (it == factoryScores.end())
         {
@@ -1417,15 +1461,18 @@
             if (!dstFactory) {
                 continue;
             }
-
             if ((dstFactory->GetImportFlags() & srcFactory->GetExportFlags()) != 0)
             {
                 auto srcCapability = srcFactory->GetCapabilities(&layer, &layer, CapabilityClass::PaddingRequired);
                 auto dstCapability = dstFactory->GetCapabilities(&connectedLayer,
                                                                  &connectedLayer,
                                                                  CapabilityClass::PaddingRequired);
+                auto srcFallback = srcFactory->GetCapabilities(&layer, &layer, CapabilityClass::FallbackImportDisabled);
+                auto dstFallback = dstFactory->GetCapabilities(&connectedLayer,
+                                                               &connectedLayer,
+                                                               CapabilityClass::FallbackImportDisabled);
                 // Do not require memory copy if the source and destination do not require padding.
-                if (srcCapability.empty() && dstCapability.empty())
+                if (srcCapability.empty() && dstCapability.empty() && srcFallback.empty() && dstFallback.empty())
                 {
                     return EdgeStrategy::ExportToTarget;
                 }
@@ -1477,13 +1524,13 @@
             switch(layer->GetType())
             {
                 case LayerType::Input:
-                    slotOption = CalculateSlotOptionForInput(backends, outputSlot, registry);
+                    slotOption = CalculateSlotOptionForInput(backends, outputSlot, registry, importEnabled);
                     break;
                 case LayerType::Output:
                     slotOption = CalculateSlotOptionForOutput(backends, outputSlot, registry);
                     break;
                 default:
-                    slotOption = CalculateSlotOption(backends, outputSlot, registry);
+                    slotOption = CalculateSlotOption(backends, outputSlot, registry, importEnabled);
                     break;
             }
             outputSlot.SetTensorHandleFactory(slotOption);