IVGCVSW-5816 Constant memory access
* Add new class ManagedConstTensorHandle, which unmaps its wrapped tensor handle when it goes out of scope
* Integrate it into existing layers that have constant tensors
* Add unit tests
Signed-off-by: Francis Murtagh <francis.murtagh@arm.com>
Change-Id: I0a05e14e438804b37e9862e76b5ca329483f6b45
diff --git a/src/backends/backendsCommon/CpuTensorHandle.hpp b/src/backends/backendsCommon/CpuTensorHandle.hpp
index a300fe0..fdd2439 100644
--- a/src/backends/backendsCommon/CpuTensorHandle.hpp
+++ b/src/backends/backendsCommon/CpuTensorHandle.hpp
@@ -175,4 +175,71 @@
template <>
void* CpuTensorHandle::GetTensor() const;
+/// RAII wrapper around a ConstCpuTensorHandle.
+/// Records whether Map() has been called and Unmaps the wrapped handle when this object
+/// goes out of scope (or when Unmap() is called explicitly).
+/// A null handle is permitted, e.g. a bias handle that is created empty before it is known
+/// whether bias is enabled. Map() and GetTensorInfo() throw armnn::Exception on a null handle.
+class ManagedConstTensorHandle
+{
+public:
+    /// @param ptr Shared tensor handle to manage; may be null.
+    explicit ManagedConstTensorHandle(std::shared_ptr<ConstCpuTensorHandle> ptr)
+        : m_Mapped(false)
+        , m_TensorHandle(std::move(ptr))
+    {}
+
+    /// Maps the wrapped handle and remembers that it is mapped, so that it will be Unmapped later.
+    /// @param blocking Forwarded unchanged to the wrapped handle's Map().
+    /// @return The pointer returned by the wrapped handle's Map().
+    /// @throws armnn::Exception if the wrapped handle is null.
+    const void* Map(bool blocking = true)
+    {
+        if (!m_TensorHandle)
+        {
+            throw armnn::Exception("Attempting to Map null TensorHandle");
+        }
+        auto pRet = m_TensorHandle->Map(blocking);
+        m_Mapped = true;
+        return pRet;
+    }
+
+    // Copying or moving would let two wrappers share the same "mapped" state and each Unmap the
+    // same handle, so every copy and move operation is deleted. Note that the copy constructor
+    // must take a ManagedConstTensorHandle; the previous declaration took a ConstCpuTensorHandle
+    // and therefore deleted an unrelated converting constructor instead.
+    ManagedConstTensorHandle(const ManagedConstTensorHandle& other) = delete;
+    ManagedConstTensorHandle& operator=(const ManagedConstTensorHandle& other) = delete;
+    ManagedConstTensorHandle(ManagedConstTensorHandle&& other) noexcept = delete;
+    ManagedConstTensorHandle& operator=(ManagedConstTensorHandle&& other) noexcept = delete;
+
+    /// Unmaps the wrapped handle if it is currently mapped.
+    /// Unmap() itself tolerates a null handle, so an empty wrapper is destroyed without effect.
+    ~ManagedConstTensorHandle()
+    {
+        Unmap();
+    }
+
+    /// Unmaps the wrapped handle. This is a no-op if the handle is null or not currently mapped,
+    /// so calling it repeatedly is safe.
+    void Unmap()
+    {
+        if (m_Mapped && m_TensorHandle)
+        {
+            m_TensorHandle->Unmap();
+            m_Mapped = false;
+        }
+    }
+
+    /// @return The TensorInfo of the wrapped handle.
+    /// @throws armnn::Exception if the wrapped handle is null (previously a null dereference).
+    const TensorInfo& GetTensorInfo() const
+    {
+        if (!m_TensorHandle)
+        {
+            throw armnn::Exception("Attempting to get TensorInfo of null TensorHandle");
+        }
+        return m_TensorHandle->GetTensorInfo();
+    }
+
+    /// @return true between a successful Map() and the next Unmap().
+    bool IsMapped() const
+    {
+        return m_Mapped;
+    }
+
+private:
+    bool m_Mapped;
+    std::shared_ptr<ConstCpuTensorHandle> m_TensorHandle;
+};
+
} // namespace armnn
diff --git a/src/backends/backendsCommon/test/DefaultAsyncExecuteTest.cpp b/src/backends/backendsCommon/test/DefaultAsyncExecuteTest.cpp
index 0d45952..56a794e 100644
--- a/src/backends/backendsCommon/test/DefaultAsyncExecuteTest.cpp
+++ b/src/backends/backendsCommon/test/DefaultAsyncExecuteTest.cpp
@@ -243,7 +243,6 @@
ValidateTensor(workingMemDescriptor2.m_Inputs[0], expectedExecuteval2);
}
-
BOOST_AUTO_TEST_SUITE_END()
}
\ No newline at end of file
diff --git a/src/backends/reference/test/RefTensorHandleTests.cpp b/src/backends/reference/test/RefTensorHandleTests.cpp
index 1ef6de9..b04d9d6 100644
--- a/src/backends/reference/test/RefTensorHandleTests.cpp
+++ b/src/backends/reference/test/RefTensorHandleTests.cpp
@@ -167,6 +167,39 @@
ARMNN_ASSERT(!(handleFactory.SupportsInPlaceComputation()));
}
+BOOST_AUTO_TEST_CASE(TestManagedConstTensorHandle)
+{
+    // Null memory and a default TensorInfo suffice: only the map/unmap bookkeeping is exercised.
+    void* memory = nullptr;
+    TensorInfo info;
+
+    // PassthroughCpuTensorHandle is used because the other CPU tensor handle types are abstract.
+    auto backingHandle = std::make_shared<PassthroughCpuTensorHandle>(info, memory);
+
+    // A fresh wrapper starts unmapped; Map() sets the flag and Unmap() clears it again.
+    ManagedConstTensorHandle managed(backingHandle);
+    BOOST_CHECK(!managed.IsMapped());
+    managed.Map();
+    BOOST_CHECK(managed.IsMapped());
+    managed.Unmap();
+    BOOST_CHECK(!managed.IsMapped());
+
+    // GetTensorInfo() reports the info of the wrapped handle.
+    BOOST_CHECK(managed.GetTensorInfo() == info);
+
+    // Wrapping a null handle: Map() throws and the wrapper never becomes mapped.
+    ManagedConstTensorHandle managedNull(nullptr);
+    BOOST_CHECK(!managedNull.IsMapped());
+    BOOST_CHECK_THROW(managedNull.Map(), armnn::Exception);
+    BOOST_CHECK(!managedNull.IsMapped());
+
+    // Unmap() on a wrapper that is not mapped is a harmless no-op.
+    managedNull.Unmap();
+    BOOST_CHECK(!managedNull.IsMapped());
+}
+
#if !defined(__ANDROID__)
// Only run these tests on non Android platforms
BOOST_AUTO_TEST_CASE(CheckSourceType)