blob: b22eb52ed350ee8c03fbbb3325efaedbdc3ee40e [file] [log] [blame]
Colm Donelanc74b1752021-03-12 15:58:48 +00001//
2// Copyright © 2021 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#pragma once
6
Colm Donelanc74b1752021-03-12 15:58:48 +00007#include <armnn/MemorySources.hpp>
8#include <armnn/backends/IMemoryManager.hpp>
9#include <armnn/backends/ITensorHandleFactory.hpp>
10
11namespace armnn
12{
13
/**
 * Identifier string of ClImportTensorHandleFactory.
 *
 * Declared constexpr so the id can be used in constant expressions.
 *
 * @return Pointer to the string literal "Arm/Cl/ImportTensorHandleFactory".
 */
constexpr auto ClImportTensorHandleFactoryId() -> const char*
{
    constexpr const char* factoryId = "Arm/Cl/ImportTensorHandleFactory";
    return factoryId;
}
18
19/**
David Monahane4a41dc2021-04-14 16:55:36 +010020 * This factory creates ClImportTensorHandles that refer to imported memory tensors.
Colm Donelanc74b1752021-03-12 15:58:48 +000021 */
22class ClImportTensorHandleFactory : public ITensorHandleFactory
23{
24public:
25 static const FactoryId m_Id;
26
27 /**
28 * Create a tensor handle factory for tensors that will be imported or exported.
29 *
30 * @param importFlags
31 * @param exportFlags
32 */
33 ClImportTensorHandleFactory(MemorySourceFlags importFlags, MemorySourceFlags exportFlags)
34 : m_ImportFlags(importFlags)
35 , m_ExportFlags(exportFlags)
36 {}
37
38 std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle& parent,
39 const TensorShape& subTensorShape,
40 const unsigned int* subTensorOrigin) const override;
41
42 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo) const override;
43
44 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
45 DataLayout dataLayout) const override;
46
47 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
48 const bool IsMemoryManaged) const override;
49
50 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
51 DataLayout dataLayout,
52 const bool IsMemoryManaged) const override;
53
54 static const FactoryId& GetIdStatic();
55
56 const FactoryId& GetId() const override;
57
58 bool SupportsSubTensors() const override;
59
Narumol Prangnawarate5f0b242021-05-07 17:52:36 +010060 bool SupportsMapUnmap() const override;
61
Colm Donelanc74b1752021-03-12 15:58:48 +000062 MemorySourceFlags GetExportFlags() const override;
63
64 MemorySourceFlags GetImportFlags() const override;
65
Narumol Prangnawarate5f0b242021-05-07 17:52:36 +010066 std::vector<Capability> GetCapabilities(const IConnectableLayer* layer,
67 const IConnectableLayer* connectedLayer,
68 CapabilityClass capabilityClass) override;
69
Colm Donelanc74b1752021-03-12 15:58:48 +000070private:
71 MemorySourceFlags m_ImportFlags;
72 MemorySourceFlags m_ExportFlags;
73};
74
75} // namespace armnn