Change the semantics of RefTensorHandle::Import to 'overlay' existing memory
This makes it possible to call Import on an Allocated() or memory-managed Tensor,
which is needed for the current implementation of OptimizerOptions::m_ExportEnabled
to work (as the last layer before the OutputLayer needs to be able to Import the
user's OutputTensor, but this is done after other memory allocation).
Signed-off-by: Matthew Bentham <matthew.bentham@arm.com>
Change-Id: I1a885c2da7b1f0f3964ae53b8135b5e96a66614f
diff --git a/src/backends/reference/RefTensorHandle.cpp b/src/backends/reference/RefTensorHandle.cpp
index eccdc26..dbfa374 100644
--- a/src/backends/reference/RefTensorHandle.cpp
+++ b/src/backends/reference/RefTensorHandle.cpp
@@ -12,8 +12,7 @@
m_MemoryManager(memoryManager),
m_Pool(nullptr),
m_UnmanagedMemory(nullptr),
- m_Imported(false),
- m_IsImportEnabled(false)
+ m_ImportedMemory(nullptr)
{
}
@@ -22,59 +21,46 @@
: m_TensorInfo(tensorInfo),
m_Pool(nullptr),
m_UnmanagedMemory(nullptr),
- m_Imported(false),
- m_IsImportEnabled(true)
+ m_ImportedMemory(nullptr)
{
}
RefTensorHandle::~RefTensorHandle()
{
- if (!m_Pool)
- {
- // unmanaged
- if (!m_Imported)
- {
- ::operator delete(m_UnmanagedMemory);
- }
- }
+ ::operator delete(m_UnmanagedMemory);
}
void RefTensorHandle::Manage()
{
- if (!m_IsImportEnabled)
- {
- ARMNN_ASSERT_MSG(!m_Pool, "RefTensorHandle::Manage() called twice");
- ARMNN_ASSERT_MSG(!m_UnmanagedMemory, "RefTensorHandle::Manage() called after Allocate()");
+ ARMNN_ASSERT_MSG(!m_Pool, "RefTensorHandle::Manage() called twice");
+ ARMNN_ASSERT_MSG(!m_UnmanagedMemory, "RefTensorHandle::Manage() called after Allocate()");
+ if (m_MemoryManager)
+ {
m_Pool = m_MemoryManager->Manage(m_TensorInfo.GetNumBytes());
}
}
void RefTensorHandle::Allocate()
{
- // If import is enabled, do not allocate the tensor
- if (!m_IsImportEnabled)
+ if (!m_UnmanagedMemory)
{
-
- if (!m_UnmanagedMemory)
+ if (!m_Pool)
{
- if (!m_Pool)
- {
- // unmanaged
- m_UnmanagedMemory = ::operator new(m_TensorInfo.GetNumBytes());
- }
- else
- {
- m_MemoryManager->Allocate(m_Pool);
- }
+ // unmanaged
+ m_UnmanagedMemory = ::operator new(m_TensorInfo.GetNumBytes());
}
else
{
- throw InvalidArgumentException("RefTensorHandle::Allocate Trying to allocate a RefTensorHandle"
- "that already has allocated memory.");
+ m_MemoryManager->Allocate(m_Pool);
}
}
+ else
+ {
+ throw InvalidArgumentException("RefTensorHandle::Allocate Trying to allocate a RefTensorHandle"
+ "that already has allocated memory.");
+ }
}
const void* RefTensorHandle::Map(bool /*unused*/) const
@@ -84,7 +70,11 @@
void* RefTensorHandle::GetPointer() const
{
- if (m_UnmanagedMemory)
+ if (m_ImportedMemory)
+ {
+ return m_ImportedMemory;
+ }
+ else if (m_UnmanagedMemory)
{
return m_UnmanagedMemory;
}
@@ -114,51 +104,22 @@
MemorySourceFlags RefTensorHandle::GetImportFlags() const
{
- if (m_IsImportEnabled)
- {
- return static_cast<MemorySourceFlags>(MemorySource::Malloc);
- }
- else
- {
- return static_cast<MemorySourceFlags>(MemorySource::Undefined);
- }
+ return static_cast<MemorySourceFlags>(MemorySource::Malloc);
}
bool RefTensorHandle::Import(void* memory, MemorySource source)
{
- if (m_IsImportEnabled && source == MemorySource::Malloc)
+ if (source == MemorySource::Malloc)
{
// Check memory alignment
if(!CanBeImported(memory, source))
{
- if (m_Imported)
- {
- m_Imported = false;
- m_UnmanagedMemory = nullptr;
- }
+ m_ImportedMemory = nullptr;
return false;
}
- // m_UnmanagedMemory not yet allocated.
- if (!m_Imported && !m_UnmanagedMemory)
- {
- m_UnmanagedMemory = memory;
- m_Imported = true;
- return true;
- }
-
- // m_UnmanagedMemory initially allocated with Allocate().
- if (!m_Imported && m_UnmanagedMemory)
- {
- return false;
- }
-
- // m_UnmanagedMemory previously imported.
- if (m_Imported)
- {
- m_UnmanagedMemory = memory;
- return true;
- }
+ m_ImportedMemory = memory;
+ return true;
}
return false;
@@ -166,7 +127,7 @@
bool RefTensorHandle::CanBeImported(void *memory, MemorySource source)
{
- if (m_IsImportEnabled && source == MemorySource::Malloc)
+ if (source == MemorySource::Malloc)
{
uintptr_t alignment = GetDataTypeSize(m_TensorInfo.GetDataType());
if (reinterpret_cast<uintptr_t>(memory) % alignment)
diff --git a/src/backends/reference/RefTensorHandle.hpp b/src/backends/reference/RefTensorHandle.hpp
index d916b39..b4dedd5 100644
--- a/src/backends/reference/RefTensorHandle.hpp
+++ b/src/backends/reference/RefTensorHandle.hpp
@@ -71,8 +71,7 @@
std::shared_ptr<RefMemoryManager> m_MemoryManager;
RefMemoryManager::Pool* m_Pool;
mutable void* m_UnmanagedMemory;
- bool m_Imported;
- bool m_IsImportEnabled;
+ void* m_ImportedMemory;
};
}
diff --git a/src/backends/reference/test/RefTensorHandleTests.cpp b/src/backends/reference/test/RefTensorHandleTests.cpp
index b5fcc21..883df6f 100644
--- a/src/backends/reference/test/RefTensorHandleTests.cpp
+++ b/src/backends/reference/test/RefTensorHandleTests.cpp
@@ -99,8 +99,14 @@
memoryManager->Release();
float testPtr[2] = { 2.5f, 5.5f };
- // Cannot import as import is disabled
- CHECK(!handle->Import(static_cast<void*>(testPtr), MemorySource::Malloc));
+ // Check import overlays contents
+ CHECK(handle->Import(static_cast<void*>(testPtr), MemorySource::Malloc));
+ {
+ float* buffer = reinterpret_cast<float*>(handle->Map());
+ CHECK(buffer != nullptr); // Yields a valid pointer
+ CHECK(buffer[0] == 2.5f); // Memory is writable and readable
+ CHECK(buffer[1] == 5.5f); // Memory is writable and readable
+ }
}
TEST_CASE("RefTensorHandleFactoryImport")
@@ -115,11 +121,12 @@
handle->Allocate();
memoryManager->Acquire();
- // No buffer allocated when import is enabled
- CHECK_THROWS_AS(handle->Map(), armnn::NullPointerException);
+ // Check storage has been allocated
+ void* unmanagedStorage = handle->Map();
+ CHECK(unmanagedStorage != nullptr);
+ // Check importing overlays the storage
float testPtr[2] = { 2.5f, 5.5f };
- // Correctly import
CHECK(handle->Import(static_cast<void*>(testPtr), MemorySource::Malloc));
float* buffer = reinterpret_cast<float*>(handle->Map());
CHECK(buffer != nullptr); // Yields a valid pointer after import
@@ -142,11 +149,11 @@
handle.Manage();
handle.Allocate();
- // No buffer allocated when import is enabled
- CHECK_THROWS_AS(handle.Map(), armnn::NullPointerException);
+ // Check unmanaged memory allocated
+ CHECK(handle.Map());
float testPtr[2] = { 2.5f, 5.5f };
- // Correctly import
+ // Check imoport overlays the unamaged memory
CHECK(handle.Import(static_cast<void*>(testPtr), MemorySource::Malloc));
float* buffer = reinterpret_cast<float*>(handle.Map());
CHECK(buffer != nullptr); // Yields a valid pointer after import