blob: ce3ce5c0d7128d22ebfb333fa5cb1c0924794309 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "NeonTensorHandleFactory.hpp"
7#include "NeonTensorHandle.hpp"
8
Narumol Prangnawarat1a268962020-07-27 15:52:13 +01009#include "Layer.hpp"
10
Jan Eilers8eb25602020-03-09 12:13:48 +000011#include <armnn/utility/IgnoreUnused.hpp>
Matthew Sloyan171214c2020-09-09 09:07:37 +010012#include <armnn/utility/NumericCast.hpp>
Jan Eilersbb446e52020-04-02 13:56:54 +010013#include <armnn/utility/PolymorphicDowncast.hpp>
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010014
15namespace armnn
16{
17
Jan Eilerse9f0f0f2019-08-16 10:28:37 +010018using FactoryId = ITensorHandleFactory::FactoryId;
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010019
20std::unique_ptr<ITensorHandle> NeonTensorHandleFactory::CreateSubTensorHandle(ITensorHandle& parent,
Jan Eilerse9f0f0f2019-08-16 10:28:37 +010021 const TensorShape& subTensorShape,
22 const unsigned int* subTensorOrigin)
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010023 const
24{
25 const arm_compute::TensorShape shape = armcomputetensorutils::BuildArmComputeTensorShape(subTensorShape);
26
27 arm_compute::Coordinates coords;
28 coords.set_num_dimensions(subTensorShape.GetNumDimensions());
Jan Eilerse9f0f0f2019-08-16 10:28:37 +010029 for (unsigned int i = 0; i < subTensorShape.GetNumDimensions(); ++i)
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010030 {
31 // Arm compute indexes tensor coords in reverse order.
32 unsigned int revertedIndex = subTensorShape.GetNumDimensions() - i - 1;
Matthew Sloyan171214c2020-09-09 09:07:37 +010033 coords.set(i, armnn::numeric_cast<int>(subTensorOrigin[revertedIndex]));
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010034 }
35
36 const arm_compute::TensorShape parentShape = armcomputetensorutils::BuildArmComputeTensorShape(parent.GetShape());
David Monahan49895f42020-07-21 11:16:51 +010037
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010038 if (!::arm_compute::error_on_invalid_subtensor(__func__, __FILE__, __LINE__, parentShape, coords, shape))
39 {
40 return nullptr;
41 }
42
43 return std::make_unique<NeonSubTensorHandle>(
Jan Eilersbb446e52020-04-02 13:56:54 +010044 PolymorphicDowncast<IAclTensorHandle*>(&parent), shape, coords);
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010045}
46
David Monahanc6e5a6e2019-10-02 09:33:57 +010047std::unique_ptr<ITensorHandle> NeonTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo) const
48{
49 return NeonTensorHandleFactory::CreateTensorHandle(tensorInfo, true);
50}
51
52std::unique_ptr<ITensorHandle> NeonTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
53 DataLayout dataLayout) const
54{
55 return NeonTensorHandleFactory::CreateTensorHandle(tensorInfo, dataLayout, true);
56}
57
David Monahan3fb7e102019-08-20 11:25:29 +010058std::unique_ptr<ITensorHandle> NeonTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
59 const bool IsMemoryManaged) const
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010060{
61 auto tensorHandle = std::make_unique<NeonTensorHandle>(tensorInfo);
David Monahan3fb7e102019-08-20 11:25:29 +010062 if (IsMemoryManaged)
63 {
64 tensorHandle->SetMemoryGroup(m_MemoryManager->GetInterLayerMemoryGroup());
65 }
66 // If we are not Managing the Memory then we must be importing
67 tensorHandle->SetImportEnabledFlag(!IsMemoryManaged);
James Conroy57d10b72019-10-25 09:44:14 +010068 tensorHandle->SetImportFlags(GetImportFlags());
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010069
70 return tensorHandle;
71}
72
73std::unique_ptr<ITensorHandle> NeonTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
David Monahan3fb7e102019-08-20 11:25:29 +010074 DataLayout dataLayout,
75 const bool IsMemoryManaged) const
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010076{
77 auto tensorHandle = std::make_unique<NeonTensorHandle>(tensorInfo, dataLayout);
David Monahan3fb7e102019-08-20 11:25:29 +010078 if (IsMemoryManaged)
79 {
80 tensorHandle->SetMemoryGroup(m_MemoryManager->GetInterLayerMemoryGroup());
81 }
82 // If we are not Managing the Memory then we must be importing
83 tensorHandle->SetImportEnabledFlag(!IsMemoryManaged);
James Conroy57d10b72019-10-25 09:44:14 +010084 tensorHandle->SetImportFlags(GetImportFlags());
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010085
86 return tensorHandle;
87}
88
Jan Eilerse9f0f0f2019-08-16 10:28:37 +010089const FactoryId& NeonTensorHandleFactory::GetIdStatic()
90{
91 static const FactoryId s_Id(NeonTensorHandleFactoryId());
92 return s_Id;
93}
94
Ferran Balaguerbfeb2712019-08-07 15:14:56 +010095const FactoryId& NeonTensorHandleFactory::GetId() const
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010096{
Jan Eilerse9f0f0f2019-08-16 10:28:37 +010097 return GetIdStatic();
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010098}
99
Sadik Armaganab3bd4d2020-08-25 11:48:00 +0100100bool NeonTensorHandleFactory::SupportsInPlaceComputation() const
101{
102 return true;
103}
104
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +0100105bool NeonTensorHandleFactory::SupportsSubTensors() const
106{
107 return true;
108}
109
110MemorySourceFlags NeonTensorHandleFactory::GetExportFlags() const
111{
Narumol Prangnawaratb8d771a2020-08-14 11:51:12 +0100112 return m_ExportFlags;
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +0100113}
114
115MemorySourceFlags NeonTensorHandleFactory::GetImportFlags() const
116{
Narumol Prangnawaratb8d771a2020-08-14 11:51:12 +0100117 return m_ImportFlags;
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +0100118}
119
Narumol Prangnawarat1a268962020-07-27 15:52:13 +0100120std::vector<Capability> NeonTensorHandleFactory::GetCapabilities(const IConnectableLayer* layer,
121 const IConnectableLayer* connectedLayer,
122 CapabilityClass capabilityClass)
123
124{
125 IgnoreUnused(connectedLayer);
126 std::vector<Capability> capabilities;
127 if (capabilityClass == CapabilityClass::PaddingRequired)
128 {
129 auto search = paddingRequiredLayers.find((PolymorphicDowncast<const Layer*>(layer))->GetType());
130 if ( search != paddingRequiredLayers.end())
131 {
132 Capability paddingCapability(CapabilityClass::PaddingRequired, true);
133 capabilities.push_back(paddingCapability);
134 }
135 }
136 return capabilities;
137}
138
Ferran Balaguerbfeb2712019-08-07 15:14:56 +0100139} // namespace armnn