blob: 1dd83950cdfc1aad57b1b2062cfc98f4939ab65a [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "NeonTensorHandleFactory.hpp"
7#include "NeonTensorHandle.hpp"
8
Narumol Prangnawarat1a268962020-07-27 15:52:13 +01009#include "Layer.hpp"
10
Jan Eilers8eb25602020-03-09 12:13:48 +000011#include <armnn/utility/IgnoreUnused.hpp>
Jan Eilersbb446e52020-04-02 13:56:54 +010012#include <armnn/utility/PolymorphicDowncast.hpp>
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010013
14namespace armnn
15{
16
Jan Eilerse9f0f0f2019-08-16 10:28:37 +010017using FactoryId = ITensorHandleFactory::FactoryId;
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010018
19std::unique_ptr<ITensorHandle> NeonTensorHandleFactory::CreateSubTensorHandle(ITensorHandle& parent,
Jan Eilerse9f0f0f2019-08-16 10:28:37 +010020 const TensorShape& subTensorShape,
21 const unsigned int* subTensorOrigin)
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010022 const
23{
24 const arm_compute::TensorShape shape = armcomputetensorutils::BuildArmComputeTensorShape(subTensorShape);
25
26 arm_compute::Coordinates coords;
27 coords.set_num_dimensions(subTensorShape.GetNumDimensions());
Jan Eilerse9f0f0f2019-08-16 10:28:37 +010028 for (unsigned int i = 0; i < subTensorShape.GetNumDimensions(); ++i)
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010029 {
30 // Arm compute indexes tensor coords in reverse order.
31 unsigned int revertedIndex = subTensorShape.GetNumDimensions() - i - 1;
32 coords.set(i, boost::numeric_cast<int>(subTensorOrigin[revertedIndex]));
33 }
34
35 const arm_compute::TensorShape parentShape = armcomputetensorutils::BuildArmComputeTensorShape(parent.GetShape());
David Monahan49895f42020-07-21 11:16:51 +010036
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010037 if (!::arm_compute::error_on_invalid_subtensor(__func__, __FILE__, __LINE__, parentShape, coords, shape))
38 {
39 return nullptr;
40 }
41
42 return std::make_unique<NeonSubTensorHandle>(
Jan Eilersbb446e52020-04-02 13:56:54 +010043 PolymorphicDowncast<IAclTensorHandle*>(&parent), shape, coords);
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010044}
45
David Monahanc6e5a6e2019-10-02 09:33:57 +010046std::unique_ptr<ITensorHandle> NeonTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo) const
47{
48 return NeonTensorHandleFactory::CreateTensorHandle(tensorInfo, true);
49}
50
51std::unique_ptr<ITensorHandle> NeonTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
52 DataLayout dataLayout) const
53{
54 return NeonTensorHandleFactory::CreateTensorHandle(tensorInfo, dataLayout, true);
55}
56
David Monahan3fb7e102019-08-20 11:25:29 +010057std::unique_ptr<ITensorHandle> NeonTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
58 const bool IsMemoryManaged) const
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010059{
60 auto tensorHandle = std::make_unique<NeonTensorHandle>(tensorInfo);
David Monahan3fb7e102019-08-20 11:25:29 +010061 if (IsMemoryManaged)
62 {
63 tensorHandle->SetMemoryGroup(m_MemoryManager->GetInterLayerMemoryGroup());
64 }
65 // If we are not Managing the Memory then we must be importing
66 tensorHandle->SetImportEnabledFlag(!IsMemoryManaged);
James Conroy57d10b72019-10-25 09:44:14 +010067 tensorHandle->SetImportFlags(GetImportFlags());
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010068
69 return tensorHandle;
70}
71
72std::unique_ptr<ITensorHandle> NeonTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
David Monahan3fb7e102019-08-20 11:25:29 +010073 DataLayout dataLayout,
74 const bool IsMemoryManaged) const
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010075{
76 auto tensorHandle = std::make_unique<NeonTensorHandle>(tensorInfo, dataLayout);
David Monahan3fb7e102019-08-20 11:25:29 +010077 if (IsMemoryManaged)
78 {
79 tensorHandle->SetMemoryGroup(m_MemoryManager->GetInterLayerMemoryGroup());
80 }
81 // If we are not Managing the Memory then we must be importing
82 tensorHandle->SetImportEnabledFlag(!IsMemoryManaged);
James Conroy57d10b72019-10-25 09:44:14 +010083 tensorHandle->SetImportFlags(GetImportFlags());
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010084
85 return tensorHandle;
86}
87
Jan Eilerse9f0f0f2019-08-16 10:28:37 +010088const FactoryId& NeonTensorHandleFactory::GetIdStatic()
89{
90 static const FactoryId s_Id(NeonTensorHandleFactoryId());
91 return s_Id;
92}
93
Ferran Balaguerbfeb2712019-08-07 15:14:56 +010094const FactoryId& NeonTensorHandleFactory::GetId() const
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010095{
Jan Eilerse9f0f0f2019-08-16 10:28:37 +010096 return GetIdStatic();
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010097}
98
Sadik Armaganab3bd4d2020-08-25 11:48:00 +010099bool NeonTensorHandleFactory::SupportsInPlaceComputation() const
100{
101 return true;
102}
103
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +0100104bool NeonTensorHandleFactory::SupportsSubTensors() const
105{
106 return true;
107}
108
109MemorySourceFlags NeonTensorHandleFactory::GetExportFlags() const
110{
Narumol Prangnawaratb8d771a2020-08-14 11:51:12 +0100111 return m_ExportFlags;
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +0100112}
113
114MemorySourceFlags NeonTensorHandleFactory::GetImportFlags() const
115{
Narumol Prangnawaratb8d771a2020-08-14 11:51:12 +0100116 return m_ImportFlags;
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +0100117}
118
Narumol Prangnawarat1a268962020-07-27 15:52:13 +0100119std::vector<Capability> NeonTensorHandleFactory::GetCapabilities(const IConnectableLayer* layer,
120 const IConnectableLayer* connectedLayer,
121 CapabilityClass capabilityClass)
122
123{
124 IgnoreUnused(connectedLayer);
125 std::vector<Capability> capabilities;
126 if (capabilityClass == CapabilityClass::PaddingRequired)
127 {
128 auto search = paddingRequiredLayers.find((PolymorphicDowncast<const Layer*>(layer))->GetType());
129 if ( search != paddingRequiredLayers.end())
130 {
131 Capability paddingCapability(CapabilityClass::PaddingRequired, true);
132 capabilities.push_back(paddingCapability);
133 }
134 }
135 return capabilities;
136}
137
Ferran Balaguerbfeb2712019-08-07 15:14:56 +0100138} // namespace armnn