//
// Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "SplitterLayer.hpp"

#include "LayerCloneBase.hpp"

#include <armnn/TypesUtils.hpp>
#include <armnn/backends/WorkloadData.hpp>
#include <armnn/backends/WorkloadFactory.hpp>
#include <backendsCommon/WorkloadUtils.hpp>

namespace armnn
{

SplitterLayer::SplitterLayer(const ViewsDescriptor& param, const char* name)
    : LayerWithParameters(1, param.GetNumViews(), LayerType::Splitter, param, name)
{
}

std::unique_ptr<IWorkload> SplitterLayer::CreateWorkload(const IWorkloadFactory& factory) const
{
    SplitterQueueDescriptor descriptor;

    // Copies the window origins to the descriptor.
    for (unsigned int i = 0; i < m_Param.GetNumViews(); ++i)
    {
        descriptor.m_ViewOrigins.emplace_back(
            std::vector<unsigned int>(m_Param.GetViewOrigin(i), m_Param.GetViewOrigin(i) + m_Param.GetNumDimensions()));
    }

    SetAdditionalInfo(descriptor);

    return factory.CreateWorkload(LayerType::Splitter, descriptor, PrepInfoAndDesc(descriptor));
}

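// Assigns a tensor handle to each of the layer's outputs. When the factory
// supports sub-tensors, each output is created as a view into the input
// tensor so the split needs no copy; otherwise every view is given its own
// standalone tensor handle.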
template<typename FactoryType>
void SplitterLayer::CreateTensors(const TensorHandleFactoryRegistry& registry,
                                  const FactoryType& factory,
                                  bool isMemoryManaged)
{
    // If sub-tensors are supported, then all the splitter needs to do is
    // set the outputs to be appropriate sub-tensors of the input.
    bool useSubTensors = factory.SupportsSubTensors();

    if (useSubTensors)
    {
        // Get the OutputHandler of the previous layer.
        const OutputHandler& outputHandler = GetInputSlots()[0].GetConnectedOutputSlot()->GetOutputHandler();
        const OutputSlot* slot = GetInputSlots()[0].GetConnectedOutputSlot();
        const TensorInfo& parentInfo = GetInputSlot(0).GetTensorInfo();

        ITensorHandle* inputData = outputHandler.GetData();

        std::vector<std::unique_ptr<ITensorHandle>> subTensors;

        // Check whether the split is along x or y (the two innermost dimensions).
        auto numberOfDimensions = m_Param.GetNumDimensions();

        std::set<unsigned int> axis = ComputeSplitAxis(m_Param, parentInfo.GetShape());
        std::set<unsigned int>::iterator axisIt = axis.begin();

        bool isOnXorY = m_Param.GetNumDimensions() >= 3 &&
                        ((*axisIt == numberOfDimensions - 1) ||
                         (*axisIt == numberOfDimensions - 2));

        // Creates the outputs as sub-tensors of the input.
        for (unsigned int i = 0; i < m_Param.GetNumViews(); ++i)
        {
            const TensorInfo& info = m_OutputHandlers[i].GetTensorInfo();

            OutputSlot& outSlot = GetOutputSlot(i);
            ITensorHandleFactory::FactoryId factoryId = outSlot.GetTensorHandleFactoryId();

            const unsigned int numOutputSlots = GetNumOutputSlots();

            // Sub-tensors can only be used if the split is along x or y (the two
            // innermost dimensions) and the next layers do not require padding.
            bool canUseSubTensorOnXorY = true;
            bool isTensorHandleFactory = std::is_same<armnn::ITensorHandleFactory, FactoryType>::value;
            if (isTensorHandleFactory)
            {
                for (unsigned int it = 0; it < numOutputSlots; ++it)
                {
                    InputSlot* inputSlot = GetOutputSlot(it).GetConnection(0);
                    ITensorHandleFactory* handleFactory = registry.GetFactory(factoryId);
                    std::vector<Capability> capabilities =
                        handleFactory->GetCapabilities(&(inputSlot->GetOwningLayer()),
                                                       this,
                                                       CapabilityClass::PaddingRequired);
                    if (isOnXorY)
                    {
                        canUseSubTensorOnXorY = false;
                        if (capabilities.empty())
                        {
                            canUseSubTensorOnXorY = true;
                        }
                    }

                    if (!canUseSubTensorOnXorY)
                    {
                        break;
                    }
                }
            }

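            // Tries to create this view as a sub-tensor of the input; returns an
            // empty handle if any of the preconditions below fail.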
            auto CreateSubTensor = [&]()
            {
                // Make sure that:
                // 1) quantization parameters are in the same space,
                // 2) the same TensorHandleFactory is used for the input and the split layer output,
                // 3) the output does not go to a Constant layer or an Input layer, and
                // 4) the split is along x or y (the two innermost dimensions) and the next
                //    layers do not require padding.
                if (parentInfo.IsTypeSpaceMatch(info) &&                                                  //(1)
                    factoryId == slot->GetTensorHandleFactoryId() &&                                      //(2)
                    GetOutputSlot(i).GetConnection(0)->GetOwningLayer().GetType() != LayerType::Constant && //(3)
                    GetOutputSlot(i).GetConnection(0)->GetOwningLayer().GetType() != LayerType::Input &&    //(3)
                    canUseSubTensorOnXorY)                                                                //(4)
                {
                    ARMNN_NO_DEPRECATE_WARN_BEGIN
                    return factory.CreateSubTensorHandle(*inputData,
                                                         info.GetShape(),
                                                         this->m_Param.GetViewOrigin(i));
                    ARMNN_NO_DEPRECATE_WARN_END
                }
                return std::unique_ptr<ITensorHandle>();
            };

            auto subTensor = CreateSubTensor();
            if (!subTensor)
            {
                useSubTensors = false;
                break; // Failed to create a valid sub-tensor, so stop trying with the rest of the views.
            }
            subTensors.push_back(std::move(subTensor));
        }

        if (useSubTensors)
        {
            unsigned int i = 0;
            for (auto& subTensor : subTensors)
            {
                m_OutputHandlers[i].SetData(std::move(subTensor));
                ++i;
            }
        }
    }

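    // Fall back to standalone tensor handles when sub-tensors are unsupported
    // or could not be created for every view.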
    if (!useSubTensors)
    {
        for (unsigned int i = 0; i < m_Param.GetNumViews(); ++i)
        {
            m_OutputHandlers[i].CreateTensorHandles(factory, isMemoryManaged);
        }
    }
}

void SplitterLayer::CreateTensorHandles(const TensorHandleFactoryRegistry& registry,
                                        const IWorkloadFactory& workloadFactory,
                                        const bool isMemoryManaged)
{
    OutputSlot& slot = GetOutputSlot(0);
    ITensorHandleFactory::FactoryId factoryId = slot.GetTensorHandleFactoryId();

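    // The legacy factory id means no backend-specific tensor handle factory is
    // registered, in which case the workload factory creates the handles.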
    if (factoryId == ITensorHandleFactory::LegacyFactoryId)
    {
        CreateTensors(registry, workloadFactory, isMemoryManaged);
    }
    else
    {
        ITensorHandleFactory* handleFactory = registry.GetFactory(factoryId);
        if (!handleFactory)
        {
            throw armnn::NullPointerException("registry.GetFactory returned a nullptr.");
        }
        CreateTensors(registry, *handleFactory, isMemoryManaged);
    }
}

SplitterLayer* SplitterLayer::Clone(Graph& graph) const
{
    return CloneBase<SplitterLayer>(graph, m_Param, GetName());
}

std::vector<TensorShape> SplitterLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
{
    if (inputShapes.size() != m_Param.GetNumViews())
    {
        throw armnn::Exception("inputShapes' size and m_NumViews do not match (\""
                               + std::to_string(inputShapes.size()) +
                               "\" vs \""
                               + std::to_string(m_Param.GetNumViews()) + "\")");
    }

    std::vector<TensorShape> outShapes;
    // Output shapes must match the view shapes.
    for (unsigned int viewIdx = 0; viewIdx < m_Param.GetNumViews(); viewIdx++)
    {
        const uint32_t* sizes = m_Param.GetViewSizes(viewIdx);
        outShapes.push_back(TensorShape(m_Param.GetNumDimensions(), sizes));
    }
    return outShapes;
}

void SplitterLayer::ValidateTensorShapesFromInputs()
{
    std::for_each(BeginOutputSlots(), EndOutputSlots(), [&](OutputSlot& outputSlot)
    {
        VerifyShapeInferenceType(outputSlot.GetTensorInfo().GetShape(), m_ShapeInferenceMethod);
    });

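    // Rebuild the expected output shapes from the view sizes held in the
    // descriptor, then check them against the shapes on the output slots.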
    std::vector<TensorShape> views;
    for (unsigned int viewIdx = 0; viewIdx < m_Param.GetNumViews(); viewIdx++)
    {
        const uint32_t* sizes = m_Param.GetViewSizes(viewIdx);
        views.push_back(TensorShape(m_Param.GetNumDimensions(), sizes));
    }

    auto inferredShapes = InferOutputShapes(views);

    if (inferredShapes.size() != m_Param.GetNumViews())
    {
        throw armnn::LayerValidationException("inferredShapes' size and m_NumViews do not match (\""
                                              + std::to_string(inferredShapes.size()) +
                                              "\" vs \""
                                              + std::to_string(m_Param.GetNumViews()) + "\")");
    }

    for (unsigned int viewIdx = 0; viewIdx < m_Param.GetNumViews(); viewIdx++)
    {
        ValidateAndCopyShape(GetOutputSlot(viewIdx).GetTensorInfo().GetShape(),
                             inferredShapes[viewIdx],
                             m_ShapeInferenceMethod,
                             "SplitterLayer",
                             viewIdx);
    }
}

void SplitterLayer::ExecuteStrategy(IStrategy& strategy) const
{
    strategy.ExecuteStrategy(this, GetParameters(), {}, GetName());
}

} // namespace armnn