//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "NeonTensorHandleFactory.hpp"
#include "NeonTensorHandle.hpp"

#include "Layer.hpp"

#include <armnn/utility/IgnoreUnused.hpp>
#include <armnn/utility/PolymorphicDowncast.hpp>

namespace armnn
{

using FactoryId = ITensorHandleFactory::FactoryId;

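// Creates a handle that aliases a region of the parent tensor's memory instead of allocating new storage.
// Returns nullptr when ACL cannot represent the requested sub-tensor.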
std::unique_ptr<ITensorHandle> NeonTensorHandleFactory::CreateSubTensorHandle(ITensorHandle& parent,
                                                                              const TensorShape& subTensorShape,
                                                                              const unsigned int* subTensorOrigin)
                                                                              const
{
    const arm_compute::TensorShape shape = armcomputetensorutils::BuildArmComputeTensorShape(subTensorShape);

    arm_compute::Coordinates coords;
    coords.set_num_dimensions(subTensorShape.GetNumDimensions());
    for (unsigned int i = 0; i < subTensorShape.GetNumDimensions(); ++i)
    {
        // Arm compute indexes tensor coords in reverse order.
        unsigned int revertedIndex = subTensorShape.GetNumDimensions() - i - 1;
        coords.set(i, boost::numeric_cast<int>(subTensorOrigin[revertedIndex]));
    }

    const arm_compute::TensorShape parentShape = armcomputetensorutils::BuildArmComputeTensorShape(parent.GetShape());

    // For ACL to support sub-tensors, the concat axis cannot be on x or y, and the x and y values
    // must match the parent shape.
    if (coords.x() != 0 || coords.y() != 0)
    {
        return nullptr;
    }
    if ((parentShape.x() != shape.x()) || (parentShape.y() != shape.y()))
    {
        return nullptr;
    }

    if (!::arm_compute::error_on_invalid_subtensor(__func__, __FILE__, __LINE__, parentShape, coords, shape))
    {
        return nullptr;
    }

    return std::make_unique<NeonSubTensorHandle>(
        PolymorphicDowncast<IAclTensorHandle*>(&parent), shape, coords);
}

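// The TensorInfo-only overloads default to memory-managed tensor handles.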
std::unique_ptr<ITensorHandle> NeonTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo) const
{
    return NeonTensorHandleFactory::CreateTensorHandle(tensorInfo, true);
}

std::unique_ptr<ITensorHandle> NeonTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
                                                                           DataLayout dataLayout) const
{
    return NeonTensorHandleFactory::CreateTensorHandle(tensorInfo, dataLayout, true);
}

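// Creates a standalone tensor handle. A memory-managed handle has its allocation pooled through the
// factory's inter-layer memory group; an unmanaged handle is instead flagged for memory import.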
std::unique_ptr<ITensorHandle> NeonTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
                                                                           const bool IsMemoryManaged) const
{
    auto tensorHandle = std::make_unique<NeonTensorHandle>(tensorInfo);
    if (IsMemoryManaged)
    {
        tensorHandle->SetMemoryGroup(m_MemoryManager->GetInterLayerMemoryGroup());
    }
    // If we are not managing the memory then we must be importing.
    tensorHandle->SetImportEnabledFlag(!IsMemoryManaged);
    tensorHandle->SetImportFlags(GetImportFlags());

    return tensorHandle;
}

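// As above, but constructs the handle with an explicit data layout.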
std::unique_ptr<ITensorHandle> NeonTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
                                                                           DataLayout dataLayout,
                                                                           const bool IsMemoryManaged) const
{
    auto tensorHandle = std::make_unique<NeonTensorHandle>(tensorInfo, dataLayout);
    if (IsMemoryManaged)
    {
        tensorHandle->SetMemoryGroup(m_MemoryManager->GetInterLayerMemoryGroup());
    }
    // If we are not managing the memory then we must be importing.
    tensorHandle->SetImportEnabledFlag(!IsMemoryManaged);
    tensorHandle->SetImportFlags(GetImportFlags());

    return tensorHandle;
}

const FactoryId& NeonTensorHandleFactory::GetIdStatic()
{
    static const FactoryId s_Id(NeonTensorHandleFactoryId());
    return s_Id;
}

const FactoryId& NeonTensorHandleFactory::GetId() const
{
    return GetIdStatic();
}

bool NeonTensorHandleFactory::SupportsSubTensors() const
{
    return true;
}

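// No memory export or import sources are currently advertised for this factory.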
MemorySourceFlags NeonTensorHandleFactory::GetExportFlags() const
{
    return 0;
}

MemorySourceFlags NeonTensorHandleFactory::GetImportFlags() const
{
    return 0;
}

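// Reports a PaddingRequired capability when the queried layer's type appears in paddingRequiredLayers;
// the connected layer is not considered.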
std::vector<Capability> NeonTensorHandleFactory::GetCapabilities(const IConnectableLayer* layer,
                                                                 const IConnectableLayer* connectedLayer,
                                                                 CapabilityClass capabilityClass)
{
    IgnoreUnused(connectedLayer);
    std::vector<Capability> capabilities;
    if (capabilityClass == CapabilityClass::PaddingRequired)
    {
        auto search = paddingRequiredLayers.find((PolymorphicDowncast<const Layer*>(layer))->GetType());
        if (search != paddingRequiredLayers.end())
        {
            Capability paddingCapability(CapabilityClass::PaddingRequired, true);
            capabilities.push_back(paddingCapability);
        }
    }
    return capabilities;
}

} // namespace armnn