blob: 005f2f6963c4f71e4c5663bb9bfe3c6851ddba41 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
#include <aclCommon/BaseMemoryManager.hpp>
#include <armnn/backends/ITensorHandleFactory.hpp>

#include <memory>
#include <set>
#include <vector>
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010010
11namespace armnn
12{
13
/// Identifier string of the Neon tensor handle factory.
/// Usable in constant expressions; the returned pointer refers to a string literal.
constexpr auto NeonTensorHandleFactoryId() -> const char*
{
    return "Arm/Neon/TensorHandleFactory";
}
15
Narumol Prangnawarat1a268962020-07-27 15:52:13 +010016const std::set<armnn::LayerType> paddingRequiredLayers {
17 LayerType::ArgMinMax,
Narumol Prangnawarat1a268962020-07-27 15:52:13 +010018 LayerType::Convolution2d,
19 LayerType::DepthToSpace,
20 LayerType::DepthwiseConvolution2d,
21 LayerType::Dequantize,
22 LayerType::FullyConnected,
23 LayerType::Gather,
Narumol Prangnawarat1a268962020-07-27 15:52:13 +010024 LayerType::Lstm,
25 LayerType::Mean,
Narumol Prangnawarat1a268962020-07-27 15:52:13 +010026 LayerType::Permute,
27 LayerType::Pooling2d,
28 LayerType::Quantize,
29 LayerType::QuantizedLstm,
Narumol Prangnawarat1a268962020-07-27 15:52:13 +010030 LayerType::Stack,
Narumol Prangnawarat1a268962020-07-27 15:52:13 +010031 LayerType::TransposeConvolution2d
32};
33
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010034class NeonTensorHandleFactory : public ITensorHandleFactory
35{
36public:
Jan Eilerse9f0f0f2019-08-16 10:28:37 +010037 NeonTensorHandleFactory(std::weak_ptr<NeonMemoryManager> mgr)
Narumol Prangnawaratb8d771a2020-08-14 11:51:12 +010038 : m_MemoryManager(mgr),
39 m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc)),
40 m_ExportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc))
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010041 {}
42
43 std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle& parent,
Jan Eilerse9f0f0f2019-08-16 10:28:37 +010044 const TensorShape& subTensorShape,
45 const unsigned int* subTensorOrigin) const override;
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010046
David Monahanc6e5a6e2019-10-02 09:33:57 +010047 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo) const override;
48
49 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
50 DataLayout dataLayout) const override;
51
David Monahan3fb7e102019-08-20 11:25:29 +010052 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
Francis Murtagh623069d2020-08-14 17:24:39 +010053 const bool IsMemoryManaged) const override;
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010054
55 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
David Monahan3fb7e102019-08-20 11:25:29 +010056 DataLayout dataLayout,
57 const bool IsMemoryManaged = true) const override;
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010058
Jan Eilerse9f0f0f2019-08-16 10:28:37 +010059 static const FactoryId& GetIdStatic();
60
Ferran Balaguerbfeb2712019-08-07 15:14:56 +010061 const FactoryId& GetId() const override;
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010062
Sadik Armaganab3bd4d2020-08-25 11:48:00 +010063 bool SupportsInPlaceComputation() const override;
64
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010065 bool SupportsSubTensors() const override;
66
67 MemorySourceFlags GetExportFlags() const override;
68
69 MemorySourceFlags GetImportFlags() const override;
70
Narumol Prangnawarat1a268962020-07-27 15:52:13 +010071 std::vector<Capability> GetCapabilities(const IConnectableLayer* layer,
72 const IConnectableLayer* connectedLayer,
73 CapabilityClass capabilityClass) override;
74
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010075private:
Jan Eilerse9f0f0f2019-08-16 10:28:37 +010076 mutable std::shared_ptr<NeonMemoryManager> m_MemoryManager;
Narumol Prangnawaratb8d771a2020-08-14 11:51:12 +010077 MemorySourceFlags m_ImportFlags;
78 MemorySourceFlags m_ExportFlags;
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010079};
80
81} // namespace armnn