blob: 0930d4e8d7a1a734962e1b63cb91161c2dd9e402 [file] [log] [blame]
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
#include <aclCommon/BaseMemoryManager.hpp>
#include <armnn/backends/ITensorHandleFactory.hpp>

#include <memory>
#include <set>
#include <vector>
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010010
11namespace armnn
12{
13
/// Identifier string of the Neon tensor handle factory.
/// Declared constexpr so it can be used in constant expressions.
constexpr const char* NeonTensorHandleFactoryId()
{
    return "Arm/Neon/TensorHandleFactory";
}
15
Narumol Prangnawarat1a268962020-07-27 15:52:13 +010016const std::set<armnn::LayerType> paddingRequiredLayers {
17 LayerType::ArgMinMax,
18 LayerType::Concat,
19 LayerType::Convolution2d,
20 LayerType::DepthToSpace,
21 LayerType::DepthwiseConvolution2d,
22 LayerType::Dequantize,
23 LayerType::FullyConnected,
24 LayerType::Gather,
25 LayerType::L2Normalization,
26 LayerType::Lstm,
27 LayerType::Mean,
28 LayerType::Multiplication,
29 LayerType::Normalization,
30 LayerType::Permute,
31 LayerType::Pooling2d,
32 LayerType::Quantize,
33 LayerType::QuantizedLstm,
34 LayerType::Resize,
35 LayerType::Stack,
36 LayerType::Transpose,
37 LayerType::TransposeConvolution2d
38};
39
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010040class NeonTensorHandleFactory : public ITensorHandleFactory
41{
42public:
Jan Eilerse9f0f0f2019-08-16 10:28:37 +010043 NeonTensorHandleFactory(std::weak_ptr<NeonMemoryManager> mgr)
James Conroy57d10b72019-10-25 09:44:14 +010044 : m_MemoryManager(mgr)
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010045 {}
46
47 std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle& parent,
Jan Eilerse9f0f0f2019-08-16 10:28:37 +010048 const TensorShape& subTensorShape,
49 const unsigned int* subTensorOrigin) const override;
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010050
David Monahanc6e5a6e2019-10-02 09:33:57 +010051 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo) const override;
52
53 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
54 DataLayout dataLayout) const override;
55
David Monahan3fb7e102019-08-20 11:25:29 +010056 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
57 const bool IsMemoryManaged = true) const override;
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010058
59 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
David Monahan3fb7e102019-08-20 11:25:29 +010060 DataLayout dataLayout,
61 const bool IsMemoryManaged = true) const override;
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010062
Jan Eilerse9f0f0f2019-08-16 10:28:37 +010063 static const FactoryId& GetIdStatic();
64
Ferran Balaguerbfeb2712019-08-07 15:14:56 +010065 const FactoryId& GetId() const override;
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010066
67 bool SupportsSubTensors() const override;
68
69 MemorySourceFlags GetExportFlags() const override;
70
71 MemorySourceFlags GetImportFlags() const override;
72
Narumol Prangnawarat1a268962020-07-27 15:52:13 +010073 std::vector<Capability> GetCapabilities(const IConnectableLayer* layer,
74 const IConnectableLayer* connectedLayer,
75 CapabilityClass capabilityClass) override;
76
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010077private:
Jan Eilerse9f0f0f2019-08-16 10:28:37 +010078 mutable std::shared_ptr<NeonMemoryManager> m_MemoryManager;
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010079};
80
81} // namespace armnn