Change the semantics of RefTensorHandle::Import to 'overlay' existing memory

This makes it possible to call Import on an Allocated() or memory-managed Tensor,
which is needed for the current implementation of OptimizerOptions::m_ExportEnabled
to work (as the last layer before the OutputLayer needs to be able to Import the
user's OutputTensor, but this is done after other memory allocation).

Signed-off-by: Matthew Bentham <matthew.bentham@arm.com>
Change-Id: I1a885c2da7b1f0f3964ae53b8135b5e96a66614f
diff --git a/src/backends/reference/RefTensorHandle.cpp b/src/backends/reference/RefTensorHandle.cpp
index eccdc26..dbfa374 100644
--- a/src/backends/reference/RefTensorHandle.cpp
+++ b/src/backends/reference/RefTensorHandle.cpp
@@ -12,8 +12,7 @@
     m_MemoryManager(memoryManager),
     m_Pool(nullptr),
     m_UnmanagedMemory(nullptr),
-    m_Imported(false),
-    m_IsImportEnabled(false)
+    m_ImportedMemory(nullptr)
 {
 
 }
@@ -22,59 +21,46 @@
                                  : m_TensorInfo(tensorInfo),
                                    m_Pool(nullptr),
                                    m_UnmanagedMemory(nullptr),
-                                   m_Imported(false),
-                                   m_IsImportEnabled(true)
+                                   m_ImportedMemory(nullptr)
 {
 
 }
 
 RefTensorHandle::~RefTensorHandle()
 {
-    if (!m_Pool)
-    {
-        // unmanaged
-        if (!m_Imported)
-        {
-            ::operator delete(m_UnmanagedMemory);
-        }
-    }
+    ::operator delete(m_UnmanagedMemory);
 }
 
 void RefTensorHandle::Manage()
 {
-    if (!m_IsImportEnabled)
-    {
-        ARMNN_ASSERT_MSG(!m_Pool, "RefTensorHandle::Manage() called twice");
-        ARMNN_ASSERT_MSG(!m_UnmanagedMemory, "RefTensorHandle::Manage() called after Allocate()");
+    ARMNN_ASSERT_MSG(!m_Pool, "RefTensorHandle::Manage() called twice");
+    ARMNN_ASSERT_MSG(!m_UnmanagedMemory, "RefTensorHandle::Manage() called after Allocate()");
 
+    if (m_MemoryManager)
+    {
         m_Pool = m_MemoryManager->Manage(m_TensorInfo.GetNumBytes());
     }
 }
 
 void RefTensorHandle::Allocate()
 {
-    // If import is enabled, do not allocate the tensor
-    if (!m_IsImportEnabled)
+    if (!m_UnmanagedMemory)
     {
-
-        if (!m_UnmanagedMemory)
+        if (!m_Pool)
         {
-            if (!m_Pool)
-            {
-                // unmanaged
-                m_UnmanagedMemory = ::operator new(m_TensorInfo.GetNumBytes());
-            }
-            else
-            {
-                m_MemoryManager->Allocate(m_Pool);
-            }
+            // unmanaged
+            m_UnmanagedMemory = ::operator new(m_TensorInfo.GetNumBytes());
         }
         else
         {
-            throw InvalidArgumentException("RefTensorHandle::Allocate Trying to allocate a RefTensorHandle"
-                                           "that already has allocated memory.");
+            m_MemoryManager->Allocate(m_Pool);
         }
     }
+    else
+    {
+        throw InvalidArgumentException("RefTensorHandle::Allocate Trying to allocate a RefTensorHandle"
+                                       "that already has allocated memory.");
+    }
 }
 
 const void* RefTensorHandle::Map(bool /*unused*/) const
@@ -84,7 +70,11 @@
 
 void* RefTensorHandle::GetPointer() const
 {
-    if (m_UnmanagedMemory)
+    if (m_ImportedMemory)
+    {
+        return m_ImportedMemory;
+    }
+    else if (m_UnmanagedMemory)
     {
         return m_UnmanagedMemory;
     }
@@ -114,51 +104,22 @@
 
 MemorySourceFlags RefTensorHandle::GetImportFlags() const
 {
-    if (m_IsImportEnabled)
-    {
-        return static_cast<MemorySourceFlags>(MemorySource::Malloc);
-    }
-    else
-    {
-        return static_cast<MemorySourceFlags>(MemorySource::Undefined);
-    }
+    return static_cast<MemorySourceFlags>(MemorySource::Malloc);
 }
 
 bool RefTensorHandle::Import(void* memory, MemorySource source)
 {
-    if (m_IsImportEnabled && source == MemorySource::Malloc)
+    if (source == MemorySource::Malloc)
     {
         // Check memory alignment
         if(!CanBeImported(memory, source))
         {
-            if (m_Imported)
-            {
-                m_Imported = false;
-                m_UnmanagedMemory = nullptr;
-            }
+            m_ImportedMemory = nullptr;
             return false;
         }
 
-        // m_UnmanagedMemory not yet allocated.
-        if (!m_Imported && !m_UnmanagedMemory)
-        {
-            m_UnmanagedMemory = memory;
-            m_Imported = true;
-            return true;
-        }
-
-        // m_UnmanagedMemory initially allocated with Allocate().
-        if (!m_Imported && m_UnmanagedMemory)
-        {
-            return false;
-        }
-
-        // m_UnmanagedMemory previously imported.
-        if (m_Imported)
-        {
-            m_UnmanagedMemory = memory;
-            return true;
-        }
+        m_ImportedMemory = memory;
+        return true;
     }
 
     return false;
@@ -166,7 +127,7 @@
 
 bool RefTensorHandle::CanBeImported(void *memory, MemorySource source)
 {
-    if (m_IsImportEnabled && source == MemorySource::Malloc)
+    if (source == MemorySource::Malloc)
     {
         uintptr_t alignment = GetDataTypeSize(m_TensorInfo.GetDataType());
         if (reinterpret_cast<uintptr_t>(memory) % alignment)
diff --git a/src/backends/reference/RefTensorHandle.hpp b/src/backends/reference/RefTensorHandle.hpp
index d916b39..b4dedd5 100644
--- a/src/backends/reference/RefTensorHandle.hpp
+++ b/src/backends/reference/RefTensorHandle.hpp
@@ -71,8 +71,7 @@
     std::shared_ptr<RefMemoryManager> m_MemoryManager;
     RefMemoryManager::Pool* m_Pool;
     mutable void* m_UnmanagedMemory;
-    bool m_Imported;
-    bool m_IsImportEnabled;
+    void* m_ImportedMemory;
 };
 
 }
diff --git a/src/backends/reference/test/RefTensorHandleTests.cpp b/src/backends/reference/test/RefTensorHandleTests.cpp
index b5fcc21..883df6f 100644
--- a/src/backends/reference/test/RefTensorHandleTests.cpp
+++ b/src/backends/reference/test/RefTensorHandleTests.cpp
@@ -99,8 +99,14 @@
     memoryManager->Release();
 
     float testPtr[2] = { 2.5f, 5.5f };
-    // Cannot import as import is disabled
-    CHECK(!handle->Import(static_cast<void*>(testPtr), MemorySource::Malloc));
+    // Check import overlays contents
+    CHECK(handle->Import(static_cast<void*>(testPtr), MemorySource::Malloc));
+    {
+        float* buffer = reinterpret_cast<float*>(handle->Map());
+        CHECK(buffer != nullptr); // Yields a valid pointer
+        CHECK(buffer[0] == 2.5f); // Memory is writable and readable
+        CHECK(buffer[1] == 5.5f); // Memory is writable and readable
+    }
 }
 
 TEST_CASE("RefTensorHandleFactoryImport")
@@ -115,11 +121,12 @@
     handle->Allocate();
     memoryManager->Acquire();
 
-    // No buffer allocated when import is enabled
-    CHECK_THROWS_AS(handle->Map(), armnn::NullPointerException);
+    // Check storage has been allocated
+    void* unmanagedStorage = handle->Map();
+    CHECK(unmanagedStorage != nullptr);
 
+    // Check importing overlays the storage
     float testPtr[2] = { 2.5f, 5.5f };
-    // Correctly import
     CHECK(handle->Import(static_cast<void*>(testPtr), MemorySource::Malloc));
     float* buffer = reinterpret_cast<float*>(handle->Map());
     CHECK(buffer != nullptr); // Yields a valid pointer after import
@@ -142,11 +149,11 @@
     handle.Manage();
     handle.Allocate();
 
-    // No buffer allocated when import is enabled
-    CHECK_THROWS_AS(handle.Map(), armnn::NullPointerException);
+    // Check unmanaged memory allocated
+    CHECK(handle.Map());
 
     float testPtr[2] = { 2.5f, 5.5f };
-    // Correctly import
+    // Check import overlays the unmanaged memory
     CHECK(handle.Import(static_cast<void*>(testPtr), MemorySource::Malloc));
     float* buffer = reinterpret_cast<float*>(handle.Map());
     CHECK(buffer != nullptr); // Yields a valid pointer after import