IVGCVSW-6672 Implement CanBeImported function to RefTensorHandle

Signed-off-by: Nikhil Raj <nikhil.raj@arm.com>
Change-Id: Icaa3aa7ef3e5cc3984941d095edfe8f0b2137879
diff --git a/src/backends/reference/RefTensorHandle.cpp b/src/backends/reference/RefTensorHandle.cpp
index 5229e9d..0be9708 100644
--- a/src/backends/reference/RefTensorHandle.cpp
+++ b/src/backends/reference/RefTensorHandle.cpp
@@ -122,8 +122,7 @@
         if (m_IsImportEnabled && source == MemorySource::Malloc)
         {
             // Check memory alignment
-            uintptr_t alignment = GetDataTypeSize(m_TensorInfo.GetDataType());
-            if (reinterpret_cast<uintptr_t>(memory) % alignment)
+            if(!CanBeImported(memory, source))
             {
                 if (m_Imported)
                 {
@@ -160,4 +159,24 @@
     return false;
 }
 
+bool RefTensorHandle::CanBeImported(void *memory, MemorySource source)
+{
+    if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
+    {
+        if (m_IsImportEnabled && source == MemorySource::Malloc)
+        {
+            uintptr_t alignment = GetDataTypeSize(m_TensorInfo.GetDataType());
+            if (reinterpret_cast<uintptr_t>(memory) % alignment)
+            {
+                return false;
+            }
+
+            return true;
+
+        }
+
+    }
+    return false;
+}
+
 }
diff --git a/src/backends/reference/RefTensorHandle.hpp b/src/backends/reference/RefTensorHandle.hpp
index a3264f5..a7eab03 100644
--- a/src/backends/reference/RefTensorHandle.hpp
+++ b/src/backends/reference/RefTensorHandle.hpp
@@ -57,6 +57,7 @@
     }
 
     virtual bool Import(void* memory, MemorySource source) override;
+    virtual bool CanBeImported(void* memory, MemorySource source) override;
 
 private:
     // Only used for testing
diff --git a/src/backends/reference/test/RefTensorHandleTests.cpp b/src/backends/reference/test/RefTensorHandleTests.cpp
index 39f5a2a..3504f53 100644
--- a/src/backends/reference/test/RefTensorHandleTests.cpp
+++ b/src/backends/reference/test/RefTensorHandleTests.cpp
@@ -253,6 +253,39 @@
     delete[] testPtr;
 }
 
+TEST_CASE("CheckCanBeImported")
+{
+    TensorInfo info({1}, DataType::Float32);
+    RefTensorHandle handle(info, static_cast<unsigned int>(MemorySource::Malloc));
+
+    int* testPtr = new int(4);
+
+    // Not supported
+    CHECK(!handle.CanBeImported(static_cast<void *>(testPtr), MemorySource::DmaBuf));
+
+    // Supported
+    CHECK(handle.CanBeImported(static_cast<void *>(testPtr), MemorySource::Malloc));
+
+    delete testPtr;
+
+}
+
+TEST_CASE("MisalignedCanBeImported")
+{
+    TensorInfo info({2}, DataType::Float32);
+    RefTensorHandle handle(info, static_cast<unsigned int>(MemorySource::Malloc));
+
+    // Allocate a 2 int array
+    int* testPtr = new int[2];
+
+    // Increment pointer by 1 byte
+    void* misalignedPtr = static_cast<void*>(reinterpret_cast<char*>(testPtr) + 1);
+
+    CHECK(!handle.Import(misalignedPtr, MemorySource::Malloc));
+
+    delete[] testPtr;
+}
+
 #endif
 
 }