blob: 2ca67c9d6e3bbfe2ee5522ce2fee3d49fea7d1da [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
#include <aclCommon/BaseMemoryManager.hpp>
#include <armnn/backends/ITensorHandleFactory.hpp>

#include <memory>
#include <set>
#include <vector>
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010010
11namespace armnn
12{
13
/// Returns the compile-time string identifier of the Neon tensor handle factory.
///
/// @return A pointer to the string literal "Arm/Neon/TensorHandleFactory" (static storage).
constexpr const char* NeonTensorHandleFactoryId()
{
    return "Arm/Neon/TensorHandleFactory";
}
15
Narumol Prangnawarat1a268962020-07-27 15:52:13 +010016const std::set<armnn::LayerType> paddingRequiredLayers {
17 LayerType::ArgMinMax,
18 LayerType::Concat,
19 LayerType::Convolution2d,
20 LayerType::DepthToSpace,
21 LayerType::DepthwiseConvolution2d,
22 LayerType::Dequantize,
23 LayerType::FullyConnected,
24 LayerType::Gather,
25 LayerType::L2Normalization,
26 LayerType::Lstm,
27 LayerType::Mean,
28 LayerType::Multiplication,
29 LayerType::Normalization,
30 LayerType::Permute,
31 LayerType::Pooling2d,
32 LayerType::Quantize,
33 LayerType::QuantizedLstm,
34 LayerType::Resize,
35 LayerType::Stack,
36 LayerType::Transpose,
37 LayerType::TransposeConvolution2d
38};
39
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010040class NeonTensorHandleFactory : public ITensorHandleFactory
41{
42public:
Jan Eilerse9f0f0f2019-08-16 10:28:37 +010043 NeonTensorHandleFactory(std::weak_ptr<NeonMemoryManager> mgr)
Narumol Prangnawaratb8d771a2020-08-14 11:51:12 +010044 : m_MemoryManager(mgr),
45 m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc)),
46 m_ExportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc))
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010047 {}
48
49 std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle& parent,
Jan Eilerse9f0f0f2019-08-16 10:28:37 +010050 const TensorShape& subTensorShape,
51 const unsigned int* subTensorOrigin) const override;
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010052
David Monahanc6e5a6e2019-10-02 09:33:57 +010053 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo) const override;
54
55 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
56 DataLayout dataLayout) const override;
57
David Monahan3fb7e102019-08-20 11:25:29 +010058 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
Francis Murtagh623069d2020-08-14 17:24:39 +010059 const bool IsMemoryManaged) const override;
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010060
61 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
David Monahan3fb7e102019-08-20 11:25:29 +010062 DataLayout dataLayout,
63 const bool IsMemoryManaged = true) const override;
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010064
Jan Eilerse9f0f0f2019-08-16 10:28:37 +010065 static const FactoryId& GetIdStatic();
66
Ferran Balaguerbfeb2712019-08-07 15:14:56 +010067 const FactoryId& GetId() const override;
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010068
69 bool SupportsSubTensors() const override;
70
71 MemorySourceFlags GetExportFlags() const override;
72
73 MemorySourceFlags GetImportFlags() const override;
74
Narumol Prangnawarat1a268962020-07-27 15:52:13 +010075 std::vector<Capability> GetCapabilities(const IConnectableLayer* layer,
76 const IConnectableLayer* connectedLayer,
77 CapabilityClass capabilityClass) override;
78
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010079private:
Jan Eilerse9f0f0f2019-08-16 10:28:37 +010080 mutable std::shared_ptr<NeonMemoryManager> m_MemoryManager;
Narumol Prangnawaratb8d771a2020-08-14 11:51:12 +010081 MemorySourceFlags m_ImportFlags;
82 MemorySourceFlags m_ExportFlags;
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010083};
84
85} // namespace armnn