blob: 220e6fd0de73f866cf1af25c264fff2346a1ea09 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
#include "RefMemoryManager.hpp"

#include <backendsCommon/ITensorHandleFactory.hpp>

#include <memory>
#include <utility>
11
12namespace armnn
13{
14
/// Returns the string literal "Arm/Ref/TensorHandleFactory".
///
/// Judging by its name this is the identifier associated with RefTensorHandleFactory.
/// The function is constexpr, so it may be evaluated in constant expressions.
constexpr const char* RefTensorHandleFactoryId() { return "Arm/Ref/TensorHandleFactory"; }
16
17class RefTensorHandleFactory : public ITensorHandleFactory
18{
19
20public:
21 RefTensorHandleFactory(std::shared_ptr<RefMemoryManager> mgr)
22 : m_MemoryManager(mgr),
23 m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc)),
24 m_ExportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc))
25 {}
26
27 std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle& parent,
28 TensorShape const& subTensorShape,
29 unsigned int const* subTensorOrigin) const override;
30
David Monahanc6e5a6e2019-10-02 09:33:57 +010031 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo) const override;
Ferran Balaguerbfeb2712019-08-07 15:14:56 +010032
33 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
David Monahanc6e5a6e2019-10-02 09:33:57 +010034 DataLayout dataLayout) const override;
Ferran Balaguerbfeb2712019-08-07 15:14:56 +010035
36 static const FactoryId& GetIdStatic();
37
38 const FactoryId& GetId() const override;
39
40 bool SupportsSubTensors() const override;
41
42 MemorySourceFlags GetExportFlags() const override;
43
44 MemorySourceFlags GetImportFlags() const override;
45
46private:
47 mutable std::shared_ptr<RefMemoryManager> m_MemoryManager;
48 MemorySourceFlags m_ImportFlags;
49 MemorySourceFlags m_ExportFlags;
50
51};
52
53} // namespace armnn
54