Refactor: Remove m_ImportFlags from RefTensorHandle

The import flags for a RefTensorHandle shouldn't be a data member,
as RefTensorHandle can only import from MemorySource::Malloc. Instead,
use m_IsImportEnabled to determine what to return from GetImportFlags().

Simplifies the code in Import and CanBeImported.

Signed-off-by: Matthew Bentham <matthew.bentham@arm.com>
Change-Id: Ic629858920f7dd32f99ee27f150b81d8b67144cf
diff --git a/src/backends/reference/RefTensorHandle.cpp b/src/backends/reference/RefTensorHandle.cpp
index e196b61..eccdc26 100644
--- a/src/backends/reference/RefTensorHandle.cpp
+++ b/src/backends/reference/RefTensorHandle.cpp
@@ -12,19 +12,16 @@
     m_MemoryManager(memoryManager),
     m_Pool(nullptr),
     m_UnmanagedMemory(nullptr),
-    m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Undefined)),
     m_Imported(false),
     m_IsImportEnabled(false)
 {
 
 }
 
-RefTensorHandle::RefTensorHandle(const TensorInfo& tensorInfo,
-                                 MemorySourceFlags importFlags)
+RefTensorHandle::RefTensorHandle(const TensorInfo& tensorInfo)
                                  : m_TensorInfo(tensorInfo),
                                    m_Pool(nullptr),
                                    m_UnmanagedMemory(nullptr),
-                                   m_ImportFlags(importFlags),
                                    m_Imported(false),
                                    m_IsImportEnabled(true)
 {
@@ -115,43 +112,52 @@
     memcpy(dest, src, m_TensorInfo.GetNumBytes());
 }
 
+MemorySourceFlags RefTensorHandle::GetImportFlags() const
+{
+    if (m_IsImportEnabled)
+    {
+        return static_cast<MemorySourceFlags>(MemorySource::Malloc);
+    }
+    else
+    {
+        return static_cast<MemorySourceFlags>(MemorySource::Undefined);
+    }
+}
+
 bool RefTensorHandle::Import(void* memory, MemorySource source)
 {
-    if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
+    if (m_IsImportEnabled && source == MemorySource::Malloc)
     {
-        if (m_IsImportEnabled && source == MemorySource::Malloc)
+        // Check memory alignment
+        if(!CanBeImported(memory, source))
         {
-            // Check memory alignment
-            if(!CanBeImported(memory, source))
-            {
-                if (m_Imported)
-                {
-                    m_Imported = false;
-                    m_UnmanagedMemory = nullptr;
-                }
-                return false;
-            }
-
-            // m_UnmanagedMemory not yet allocated.
-            if (!m_Imported && !m_UnmanagedMemory)
-            {
-                m_UnmanagedMemory = memory;
-                m_Imported = true;
-                return true;
-            }
-
-            // m_UnmanagedMemory initially allocated with Allocate().
-            if (!m_Imported && m_UnmanagedMemory)
-            {
-                return false;
-            }
-
-            // m_UnmanagedMemory previously imported.
             if (m_Imported)
             {
-                m_UnmanagedMemory = memory;
-                return true;
+                m_Imported = false;
+                m_UnmanagedMemory = nullptr;
             }
+            return false;
+        }
+
+        // m_UnmanagedMemory not yet allocated.
+        if (!m_Imported && !m_UnmanagedMemory)
+        {
+            m_UnmanagedMemory = memory;
+            m_Imported = true;
+            return true;
+        }
+
+        // m_UnmanagedMemory initially allocated with Allocate().
+        if (!m_Imported && m_UnmanagedMemory)
+        {
+            return false;
+        }
+
+        // m_UnmanagedMemory previously imported.
+        if (m_Imported)
+        {
+            m_UnmanagedMemory = memory;
+            return true;
         }
     }
 
@@ -160,17 +166,14 @@
 
 bool RefTensorHandle::CanBeImported(void *memory, MemorySource source)
 {
-    if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
+    if (m_IsImportEnabled && source == MemorySource::Malloc)
     {
-        if (m_IsImportEnabled && source == MemorySource::Malloc)
+        uintptr_t alignment = GetDataTypeSize(m_TensorInfo.GetDataType());
+        if (reinterpret_cast<uintptr_t>(memory) % alignment)
         {
-            uintptr_t alignment = GetDataTypeSize(m_TensorInfo.GetDataType());
-            if (reinterpret_cast<uintptr_t>(memory) % alignment)
-            {
-                return false;
-            }
-            return true;
+            return false;
         }
+        return true;
     }
     return false;
 }
diff --git a/src/backends/reference/RefTensorHandle.hpp b/src/backends/reference/RefTensorHandle.hpp
index a7eab03..d916b39 100644
--- a/src/backends/reference/RefTensorHandle.hpp
+++ b/src/backends/reference/RefTensorHandle.hpp
@@ -17,7 +17,7 @@
 public:
     RefTensorHandle(const TensorInfo& tensorInfo, std::shared_ptr<RefMemoryManager> &memoryManager);
 
-    RefTensorHandle(const TensorInfo& tensorInfo, MemorySourceFlags importFlags);
+    RefTensorHandle(const TensorInfo& tensorInfo);
 
     ~RefTensorHandle();
 
@@ -51,10 +51,7 @@
         return m_TensorInfo;
     }
 
-    virtual MemorySourceFlags GetImportFlags() const override
-    {
-        return m_ImportFlags;
-    }
+    virtual MemorySourceFlags GetImportFlags() const override;
 
     virtual bool Import(void* memory, MemorySource source) override;
     virtual bool CanBeImported(void* memory, MemorySource source) override;
@@ -74,7 +71,6 @@
     std::shared_ptr<RefMemoryManager> m_MemoryManager;
     RefMemoryManager::Pool* m_Pool;
     mutable void* m_UnmanagedMemory;
-    MemorySourceFlags m_ImportFlags;
     bool m_Imported;
     bool m_IsImportEnabled;
 };
diff --git a/src/backends/reference/RefTensorHandleFactory.cpp b/src/backends/reference/RefTensorHandleFactory.cpp
index ade27dd..da3b798 100644
--- a/src/backends/reference/RefTensorHandleFactory.cpp
+++ b/src/backends/reference/RefTensorHandleFactory.cpp
@@ -48,7 +48,7 @@
     }
     else
     {
-        return std::make_unique<RefTensorHandle>(tensorInfo, m_ImportFlags);
+        return std::make_unique<RefTensorHandle>(tensorInfo);
     }
 }
 
@@ -63,7 +63,7 @@
     }
     else
     {
-        return std::make_unique<RefTensorHandle>(tensorInfo, m_ImportFlags);
+        return std::make_unique<RefTensorHandle>(tensorInfo);
     }
 }
 
@@ -87,4 +87,4 @@
     return m_ImportFlags;
 }
 
-} // namespace armnn
\ No newline at end of file
+} // namespace armnn
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index 69f75ca..bfe37d7 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -119,7 +119,7 @@
     }
     else
     {
-        return std::make_unique<RefTensorHandle>(tensorInfo, static_cast<unsigned int>(MemorySource::Malloc));
+        return std::make_unique<RefTensorHandle>(tensorInfo);
     }
 }
 
@@ -137,7 +137,7 @@
     }
     else
     {
-        return std::make_unique<RefTensorHandle>(tensorInfo, static_cast<unsigned int>(MemorySource::Malloc));
+        return std::make_unique<RefTensorHandle>(tensorInfo);
     }
 }
 
diff --git a/src/backends/reference/test/RefTensorHandleTests.cpp b/src/backends/reference/test/RefTensorHandleTests.cpp
index 6f608e8..b5fcc21 100644
--- a/src/backends/reference/test/RefTensorHandleTests.cpp
+++ b/src/backends/reference/test/RefTensorHandleTests.cpp
@@ -137,7 +137,7 @@
 TEST_CASE("RefTensorHandleImport")
 {
     TensorInfo info({ 1, 1, 2, 1 }, DataType::Float32);
-    RefTensorHandle handle(info, static_cast<unsigned int>(MemorySource::Malloc));
+    RefTensorHandle handle(info);
 
     handle.Manage();
     handle.Allocate();
@@ -224,7 +224,7 @@
 TEST_CASE("CheckSourceType")
 {
     TensorInfo info({1}, DataType::Float32);
-    RefTensorHandle handle(info, static_cast<unsigned int>(MemorySource::Malloc));
+    RefTensorHandle handle(info);
 
     int* testPtr = new int(4);
 
@@ -243,7 +243,7 @@
 TEST_CASE("ReusePointer")
 {
     TensorInfo info({1}, DataType::Float32);
-    RefTensorHandle handle(info, static_cast<unsigned int>(MemorySource::Malloc));
+    RefTensorHandle handle(info);
 
     int* testPtr = new int(4);
 
@@ -258,7 +258,7 @@
 TEST_CASE("MisalignedPointer")
 {
     TensorInfo info({2}, DataType::Float32);
-    RefTensorHandle handle(info, static_cast<unsigned int>(MemorySource::Malloc));
+    RefTensorHandle handle(info);
 
     // Allocate a 2 int array
     int* testPtr = new int[2];
@@ -274,7 +274,7 @@
 TEST_CASE("CheckCanBeImported")
 {
     TensorInfo info({1}, DataType::Float32);
-    RefTensorHandle handle(info, static_cast<unsigned int>(MemorySource::Malloc));
+    RefTensorHandle handle(info);
 
     int* testPtr = new int(4);
 
@@ -291,7 +291,7 @@
 TEST_CASE("MisalignedCanBeImported")
 {
     TensorInfo info({2}, DataType::Float32);
-    RefTensorHandle handle(info, static_cast<unsigned int>(MemorySource::Malloc));
+    RefTensorHandle handle(info);
 
     // Allocate a 2 int array
     int* testPtr = new int[2];