blob: 7e22949647dae8dad0fc5e998e6376803172a108 [file] [log] [blame]
Colm Donelanc74b1752021-03-12 15:58:48 +00001//
2// Copyright © 2021 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#pragma once
6
7#include <aclCommon/BaseMemoryManager.hpp>
8#include <armnn/MemorySources.hpp>
9#include <armnn/backends/IMemoryManager.hpp>
10#include <armnn/backends/ITensorHandleFactory.hpp>
11
12namespace armnn
13{
14
/// Identifier string of the CL import tensor handle factory.
///
/// @return Pointer to the string literal "Arm/Cl/ImportTensorHandleFactory"
///         (static storage duration, so the pointer never dangles).
constexpr const char* ClImportTensorHandleFactoryId()
{
    constexpr const char* factoryId = "Arm/Cl/ImportTensorHandleFactory";
    return factoryId;
}
19
20/**
David Monahane4a41dc2021-04-14 16:55:36 +010021 * This factory creates ClImportTensorHandles that refer to imported memory tensors.
Colm Donelanc74b1752021-03-12 15:58:48 +000022 */
23class ClImportTensorHandleFactory : public ITensorHandleFactory
24{
25public:
26 static const FactoryId m_Id;
27
28 /**
29 * Create a tensor handle factory for tensors that will be imported or exported.
30 *
31 * @param importFlags
32 * @param exportFlags
33 */
34 ClImportTensorHandleFactory(MemorySourceFlags importFlags, MemorySourceFlags exportFlags)
35 : m_ImportFlags(importFlags)
36 , m_ExportFlags(exportFlags)
37 {}
38
39 std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle& parent,
40 const TensorShape& subTensorShape,
41 const unsigned int* subTensorOrigin) const override;
42
43 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo) const override;
44
45 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
46 DataLayout dataLayout) const override;
47
48 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
49 const bool IsMemoryManaged) const override;
50
51 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
52 DataLayout dataLayout,
53 const bool IsMemoryManaged) const override;
54
55 static const FactoryId& GetIdStatic();
56
57 const FactoryId& GetId() const override;
58
59 bool SupportsSubTensors() const override;
60
Narumol Prangnawarate5f0b242021-05-07 17:52:36 +010061 bool SupportsMapUnmap() const override;
62
Colm Donelanc74b1752021-03-12 15:58:48 +000063 MemorySourceFlags GetExportFlags() const override;
64
65 MemorySourceFlags GetImportFlags() const override;
66
Narumol Prangnawarate5f0b242021-05-07 17:52:36 +010067 std::vector<Capability> GetCapabilities(const IConnectableLayer* layer,
68 const IConnectableLayer* connectedLayer,
69 CapabilityClass capabilityClass) override;
70
Colm Donelanc74b1752021-03-12 15:58:48 +000071private:
72 MemorySourceFlags m_ImportFlags;
73 MemorySourceFlags m_ExportFlags;
74};
75
76} // namespace armnn