blob: aab3faad0add1b3abdacdeea2056192e362a9537 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#pragma once
6
#include <armnn/backends/ITensorHandleFactory.hpp>
#include <aclCommon/BaseMemoryManager.hpp>
#include <armnn/backends/IMemoryManager.hpp>
#include <armnn/MemorySources.hpp>

#include <memory>
#include <utility>
11
12namespace armnn
13{
14
/// Identifier string of the Cl tensor handle factory.
/// Declared constexpr, so the string can be obtained in constant expressions.
constexpr auto ClTensorHandleFactoryId() -> const char*
{
    return "Arm/Cl/TensorHandleFactory";
}
16
17class ClTensorHandleFactory : public ITensorHandleFactory {
18public:
19 static const FactoryId m_Id;
20
21 ClTensorHandleFactory(std::shared_ptr<ClMemoryManager> mgr)
22 : m_MemoryManager(mgr),
23 m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Undefined)),
24 m_ExportFlags(static_cast<MemorySourceFlags>(MemorySource::Undefined))
25 {}
26
27 std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle& parent,
28 const TensorShape& subTensorShape,
29 const unsigned int* subTensorOrigin) const override;
30
David Monahanc6e5a6e2019-10-02 09:33:57 +010031 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo) const override;
32
33 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
34 DataLayout dataLayout) const override;
35
David Monahan3fb7e102019-08-20 11:25:29 +010036 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
37 const bool IsMemoryManaged = true) const override;
Jan Eilerse9f0f0f2019-08-16 10:28:37 +010038
39 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
David Monahan3fb7e102019-08-20 11:25:29 +010040 DataLayout dataLayout,
41 const bool IsMemoryManaged = true) const override;
Jan Eilerse9f0f0f2019-08-16 10:28:37 +010042
43 static const FactoryId& GetIdStatic();
44
Ferran Balaguerbfeb2712019-08-07 15:14:56 +010045 const FactoryId& GetId() const override;
Jan Eilerse9f0f0f2019-08-16 10:28:37 +010046
47 bool SupportsSubTensors() const override;
48
49 MemorySourceFlags GetExportFlags() const override;
50
51 MemorySourceFlags GetImportFlags() const override;
52
53private:
54 mutable std::shared_ptr<ClMemoryManager> m_MemoryManager;
55 MemorySourceFlags m_ImportFlags;
56 MemorySourceFlags m_ExportFlags;
57};
58
59} // namespace armnn