IVGCVSW-5013 Add TensorHandleFactory to Sample Dynamic Tensor

Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com>
Change-Id: I9f8367ebb59a73570a1a2de68aaadba98abef11c
diff --git a/src/dynamic/sample/CMakeLists.txt b/src/dynamic/sample/CMakeLists.txt
index a013771..0a32bf9 100644
--- a/src/dynamic/sample/CMakeLists.txt
+++ b/src/dynamic/sample/CMakeLists.txt
@@ -1,5 +1,5 @@
 #
-# Copyright © 2020 Arm Ltd. All rights reserved.
+# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
 # SPDX-License-Identifier: MIT
 #
 
@@ -26,6 +26,8 @@
         SampleMemoryManager.hpp
         SampleTensorHandle.cpp
         SampleTensorHandle.hpp
+        SampleDynamicTensorHandleFactory.cpp
+        SampleDynamicTensorHandleFactory.hpp
 )
 
 add_library(Arm_SampleDynamic_backend MODULE ${armnnSampleDynamicBackend_sources})
diff --git a/src/dynamic/sample/SampleDynamicBackend.cpp b/src/dynamic/sample/SampleDynamicBackend.cpp
index 19aaaae..2ef8faa 100644
--- a/src/dynamic/sample/SampleDynamicBackend.cpp
+++ b/src/dynamic/sample/SampleDynamicBackend.cpp
@@ -1,5 +1,5 @@
 //
-// Copyright © 2020 Arm Ltd. All rights reserved.
+// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
 // SPDX-License-Identifier: MIT
 //
 
@@ -7,6 +7,7 @@
 #include "SampleDynamicLayerSupport.hpp"
 #include "SampleDynamicWorkloadFactory.hpp"
 #include "SampleMemoryManager.hpp"
+#include "SampleDynamicTensorHandleFactory.hpp"
 
 #include <armnn/backends/IBackendInternal.hpp>
 #include <armnn/backends/OptimizationViews.hpp>
@@ -38,7 +39,8 @@
     IBackendInternal::IWorkloadFactoryPtr CreateWorkloadFactory(
         const IMemoryManagerSharedPtr& memoryManager) const override
     {
-        return std::make_unique<SampleDynamicWorkloadFactory>();
+        auto sampleMemoryManager = PolymorphicPointerDowncast<SampleMemoryManager>(memoryManager);
+        return std::make_unique<SampleDynamicWorkloadFactory>(sampleMemoryManager);
     }
 
     IBackendInternal::IWorkloadFactoryPtr CreateWorkloadFactory(
@@ -61,7 +63,7 @@
 
     std::vector<ITensorHandleFactory::FactoryId> GetHandleFactoryPreferences() const override
     {
-        return std::vector<ITensorHandleFactory::FactoryId>();
+        return { SampleDynamicTensorHandleFactory::GetIdStatic() }; // single preferred factory
     }
 
     IBackendInternal::IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const override
@@ -77,6 +79,15 @@
 
         return optimizationViews;
     }
+
+    // Registers the memory manager and its factory. Non-const, as the IBackendInternal hook is (TODO confirm).
+    void RegisterTensorHandleFactories(class TensorHandleFactoryRegistry& registry) override
+    {
+        auto memoryManager = std::make_shared<SampleMemoryManager>();
+        registry.RegisterMemoryManager(memoryManager);
+        registry.RegisterFactory(std::make_unique<SampleDynamicTensorHandleFactory>(memoryManager));
+    }
+
 };
 
 } // namespace armnn
diff --git a/src/dynamic/sample/SampleDynamicTensorHandleFactory.cpp b/src/dynamic/sample/SampleDynamicTensorHandleFactory.cpp
new file mode 100644
index 0000000..0852bed
--- /dev/null
+++ b/src/dynamic/sample/SampleDynamicTensorHandleFactory.cpp
@@ -0,0 +1,99 @@
+//
+// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "SampleDynamicTensorHandleFactory.hpp"
+#include "SampleTensorHandle.hpp"
+
+#include <armnn/utility/IgnoreUnused.hpp>
+
+namespace armnn
+{
+
+using FactoryId = ITensorHandleFactory::FactoryId;
+
+// Returns the id shared by every instance, built once from SampleDynamicTensorHandleFactoryId().
+const FactoryId& SampleDynamicTensorHandleFactory::GetIdStatic()
+{
+    static const FactoryId factoryId(SampleDynamicTensorHandleFactoryId());
+    return factoryId;
+}
+
+// Sub-tensors are not provided by this factory: the arguments are ignored and nullptr is returned.
+std::unique_ptr<ITensorHandle>
+SampleDynamicTensorHandleFactory::CreateSubTensorHandle(ITensorHandle& parent,
+                                                        TensorShape const& subTensorShape,
+                                                        unsigned int const* subTensorOrigin) const
+{
+    IgnoreUnused(parent, subTensorShape, subTensorOrigin);
+    return nullptr;
+}
+
+// Creates a SampleTensorHandle for tensorInfo, passing it the factory's memory manager.
+std::unique_ptr<ITensorHandle> SampleDynamicTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo) const
+{
+    return std::make_unique<SampleTensorHandle>(tensorInfo, m_MemoryManager);
+}
+
+// As the single-argument overload; dataLayout is accepted but not used.
+std::unique_ptr<ITensorHandle> SampleDynamicTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
+                                                                                    DataLayout dataLayout) const
+{
+    IgnoreUnused(dataLayout);
+    return std::make_unique<SampleTensorHandle>(tensorInfo, m_MemoryManager);
+}
+
+// Creates a SampleTensorHandle from the memory manager when IsMemoryManaged is true,
+// and from the factory's import flags otherwise.
+std::unique_ptr<ITensorHandle> SampleDynamicTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
+                                                                                    const bool IsMemoryManaged) const
+{
+    if (!IsMemoryManaged)
+    {
+        return std::make_unique<SampleTensorHandle>(tensorInfo, m_ImportFlags);
+    }
+
+    return std::make_unique<SampleTensorHandle>(tensorInfo, m_MemoryManager);
+}
+
+// As the (tensorInfo, IsMemoryManaged) overload; dataLayout is accepted but not used.
+std::unique_ptr<ITensorHandle> SampleDynamicTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
+                                                                                    DataLayout dataLayout,
+                                                                                    const bool IsMemoryManaged) const
+{
+    IgnoreUnused(dataLayout);
+
+    if (!IsMemoryManaged)
+    {
+        return std::make_unique<SampleTensorHandle>(tensorInfo, m_ImportFlags);
+    }
+
+    return std::make_unique<SampleTensorHandle>(tensorInfo, m_MemoryManager);
+}
+
+// Instance accessor for the id; forwards to GetIdStatic().
+const FactoryId& SampleDynamicTensorHandleFactory::GetId() const
+{
+    return GetIdStatic();
+}
+
+// Always false for this factory.
+bool SampleDynamicTensorHandleFactory::SupportsSubTensors() const
+{
+    return false;
+}
+
+// Returns the export flags held by the factory.
+MemorySourceFlags SampleDynamicTensorHandleFactory::GetExportFlags() const
+{
+    return m_ExportFlags;
+}
+
+// Returns the import flags held by the factory.
+MemorySourceFlags SampleDynamicTensorHandleFactory::GetImportFlags() const
+{
+    return m_ImportFlags;
+}
+
+} // namespace armnn
\ No newline at end of file
diff --git a/src/dynamic/sample/SampleDynamicTensorHandleFactory.hpp b/src/dynamic/sample/SampleDynamicTensorHandleFactory.hpp
new file mode 100644
index 0000000..5f5e880
--- /dev/null
+++ b/src/dynamic/sample/SampleDynamicTensorHandleFactory.hpp
@@ -0,0 +1,72 @@
+//
+// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "SampleMemoryManager.hpp"
+
+#include <armnn/backends/ITensorHandleFactory.hpp>
+
+#include <memory>
+#include <utility>
+
+namespace armnn
+{
+
+// Identifier string of the sample dynamic backend's tensor handle factory.
+constexpr const char * SampleDynamicTensorHandleFactoryId() { return "Arm/SampleDynamic/TensorHandleFactory"; }
+
+// ITensorHandleFactory implementation of the sample dynamic backend.
+class SampleDynamicTensorHandleFactory : public ITensorHandleFactory
+{
+
+public:
+    // Shares ownership of mgr. Import and export flags are both set to MemorySource::Malloc.
+    // explicit: a memory manager pointer should never convert to a factory implicitly.
+    explicit SampleDynamicTensorHandleFactory(std::shared_ptr<SampleMemoryManager> mgr)
+    : m_MemoryManager(std::move(mgr)),
+      m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc)),
+      m_ExportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc))
+    {}
+
+    // Always yields nullptr: this factory does not provide sub-tensors.
+    std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle& parent,
+                                                         TensorShape const& subTensorShape,
+                                                         unsigned int const* subTensorOrigin) const override;
+
+    // Handle creation. dataLayout is ignored by every overload. When IsMemoryManaged is false the handle
+    // is constructed from the import flags rather than from the memory manager.
+    std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo) const override;
+
+    std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
+                                                      DataLayout dataLayout) const override;
+
+    std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
+                                                      const bool IsMemoryManaged) const override;
+
+    std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
+                                                      DataLayout dataLayout,
+                                                      const bool IsMemoryManaged) const override;
+
+    // Id of this factory type; GetId() returns the same value.
+    static const FactoryId& GetIdStatic();
+
+    const FactoryId& GetId() const override;
+
+    bool SupportsSubTensors() const override;
+
+    MemorySourceFlags GetExportFlags() const override;
+
+    MemorySourceFlags GetImportFlags() const override;
+
+private:
+    // mutable so the const Create* methods can pass it on (presumably as a non-const reference; TODO confirm).
+    mutable std::shared_ptr<SampleMemoryManager> m_MemoryManager;
+    MemorySourceFlags m_ImportFlags;
+    MemorySourceFlags m_ExportFlags;
+};
+
+} // namespace armnn
+