Refactor: Remove m_ImportFlags from RefTensorHandle

The import flags for a RefTensorHandle don't need to be stored as a data member,
as RefTensorHandle can only import from MemorySource::Malloc. Instead,
use m_IsImportEnabled to determine what to return from GetImportFlags().

Simplifies Import and CanBeImported by removing the redundant outer m_ImportFlags check that wrapped the m_IsImportEnabled/Malloc check.

Signed-off-by: Matthew Bentham <matthew.bentham@arm.com>
Change-Id: Ic629858920f7dd32f99ee27f150b81d8b67144cf
diff --git a/src/backends/reference/RefTensorHandle.cpp b/src/backends/reference/RefTensorHandle.cpp
index e196b61..eccdc26 100644
--- a/src/backends/reference/RefTensorHandle.cpp
+++ b/src/backends/reference/RefTensorHandle.cpp
@@ -12,19 +12,16 @@
     m_MemoryManager(memoryManager),
     m_Pool(nullptr),
     m_UnmanagedMemory(nullptr),
-    m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Undefined)),
     m_Imported(false),
     m_IsImportEnabled(false)
 {
 
 }
 
-RefTensorHandle::RefTensorHandle(const TensorInfo& tensorInfo,
-                                 MemorySourceFlags importFlags)
+RefTensorHandle::RefTensorHandle(const TensorInfo& tensorInfo)
                                  : m_TensorInfo(tensorInfo),
                                    m_Pool(nullptr),
                                    m_UnmanagedMemory(nullptr),
-                                   m_ImportFlags(importFlags),
                                    m_Imported(false),
                                    m_IsImportEnabled(true)
 {
@@ -115,43 +112,52 @@
     memcpy(dest, src, m_TensorInfo.GetNumBytes());
 }
 
+MemorySourceFlags RefTensorHandle::GetImportFlags() const // Derived from m_IsImportEnabled; no stored import-flags member.
+{
+    if (m_IsImportEnabled) // Import enabled: report the single source Import() accepts.
+    {
+        return static_cast<MemorySourceFlags>(MemorySource::Malloc);
+    }
+    else
+    {
+        return static_cast<MemorySourceFlags>(MemorySource::Undefined); // Import disabled: no importable source.
+    }
+}
+
 bool RefTensorHandle::Import(void* memory, MemorySource source)
 {
-    if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
+    if (m_IsImportEnabled && source == MemorySource::Malloc)
     {
-        if (m_IsImportEnabled && source == MemorySource::Malloc)
+        // Check memory alignment
+        if(!CanBeImported(memory, source))
         {
-            // Check memory alignment
-            if(!CanBeImported(memory, source))
-            {
-                if (m_Imported)
-                {
-                    m_Imported = false;
-                    m_UnmanagedMemory = nullptr;
-                }
-                return false;
-            }
-
-            // m_UnmanagedMemory not yet allocated.
-            if (!m_Imported && !m_UnmanagedMemory)
-            {
-                m_UnmanagedMemory = memory;
-                m_Imported = true;
-                return true;
-            }
-
-            // m_UnmanagedMemory initially allocated with Allocate().
-            if (!m_Imported && m_UnmanagedMemory)
-            {
-                return false;
-            }
-
-            // m_UnmanagedMemory previously imported.
             if (m_Imported)
             {
-                m_UnmanagedMemory = memory;
-                return true;
+                m_Imported = false;
+                m_UnmanagedMemory = nullptr;
             }
+            return false;
+        }
+
+        // m_UnmanagedMemory not yet allocated.
+        if (!m_Imported && !m_UnmanagedMemory)
+        {
+            m_UnmanagedMemory = memory;
+            m_Imported = true;
+            return true;
+        }
+
+        // m_UnmanagedMemory initially allocated with Allocate().
+        if (!m_Imported && m_UnmanagedMemory)
+        {
+            return false;
+        }
+
+        // m_UnmanagedMemory previously imported.
+        if (m_Imported)
+        {
+            m_UnmanagedMemory = memory;
+            return true;
         }
     }
 
@@ -160,17 +166,14 @@
 
 bool RefTensorHandle::CanBeImported(void *memory, MemorySource source)
 {
-    if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
+    if (m_IsImportEnabled && source == MemorySource::Malloc) // Only Malloc memory, and only when import is enabled.
     {
-        if (m_IsImportEnabled && source == MemorySource::Malloc)
+        uintptr_t alignment = GetDataTypeSize(m_TensorInfo.GetDataType()); // Required alignment = size of the tensor's data type.
+        if (reinterpret_cast<uintptr_t>(memory) % alignment) // Pointer not a multiple of that size: reject.
         {
-            uintptr_t alignment = GetDataTypeSize(m_TensorInfo.GetDataType());
-            if (reinterpret_cast<uintptr_t>(memory) % alignment)
-            {
-                return false;
-            }
-            return true;
+            return false;
         }
+        return true;
     }
     return false;
 }