//
// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <armnn/StrategyBase.hpp>
#include <armnn/Descriptors.hpp>
#include <optimizations/FoldPadIntoLayer2d.hpp>

namespace armnn
{

namespace
{

/// Checks if a Layer has a DataLayout that is either NCHW or NCDHW.
class CheckForNCHW : public StrategyBase<NoThrowStrategy>
{
public:
    CheckForNCHW()
    {}

    void ExecuteStrategy(const armnn::IConnectableLayer* layer,
                         const armnn::BaseDescriptor& descriptor,
                         const std::vector<armnn::ConstTensor>& constants,
                         const char* name,
                         const armnn::LayerBindingId id = 0) override
    {
        armnn::IgnoreUnused(constants, id, name);
        switch (layer->GetType())
        {
            case armnn::LayerType::BatchMatMul:
            {
                auto desc = static_cast<const armnn::BatchMatMulDescriptor&>(descriptor);
                m_Result = desc.m_DataLayoutX == DataLayout::NCHW || desc.m_DataLayoutY == DataLayout::NCHW;
                break;
            }
            case armnn::LayerType::BatchNormalization:
            {
                CheckDescForNCHW(static_cast<const armnn::BatchNormalizationDescriptor&>(descriptor));
                break;
            }
            case armnn::LayerType::BatchToSpaceNd:
            {
                CheckDescForNCHW(static_cast<const armnn::BatchToSpaceNdDescriptor&>(descriptor));
                break;
            }
            case armnn::LayerType::Convolution2d:
            {
                CheckDescForNCHW(static_cast<const armnn::Convolution2dDescriptor&>(descriptor));
                break;
            }
            case armnn::LayerType::Convolution3d:
            {
                CheckDescForNCHW(static_cast<const armnn::Convolution3dDescriptor&>(descriptor));
                break;
            }
            case armnn::LayerType::DepthwiseConvolution2d:
            {
                CheckDescForNCHW(static_cast<const armnn::DepthwiseConvolution2dDescriptor&>(descriptor));
                break;
            }
            case armnn::LayerType::InstanceNormalization:
            {
                CheckDescForNCHW(static_cast<const armnn::InstanceNormalizationDescriptor&>(descriptor));
                break;
            }
            case armnn::LayerType::L2Normalization:
            {
                CheckDescForNCHW(static_cast<const armnn::L2NormalizationDescriptor&>(descriptor));
                break;
            }
            case armnn::LayerType::Normalization:
            {
                CheckDescForNCHW(static_cast<const armnn::NormalizationDescriptor&>(descriptor));
                break;
            }
            case armnn::LayerType::Pooling2d:
            {
                CheckDescForNCHW(static_cast<const armnn::Pooling2dDescriptor&>(descriptor));
                break;
            }
            case armnn::LayerType::Pooling3d:
            {
                CheckDescForNCHW(static_cast<const armnn::Pooling3dDescriptor&>(descriptor));
                break;
            }
            case armnn::LayerType::SpaceToBatchNd:
            {
                CheckDescForNCHW(static_cast<const armnn::SpaceToBatchNdDescriptor&>(descriptor));
                break;
            }
            case armnn::LayerType::SpaceToDepth:
            {
                CheckDescForNCHW(static_cast<const armnn::SpaceToDepthDescriptor&>(descriptor));
                break;
            }
            case armnn::LayerType::StridedSlice:
            {
                CheckDescForNCHW(static_cast<const armnn::StridedSliceDescriptor&>(descriptor));
                break;
            }
            default:
            {
                m_Result = false;
            }
        }
    }

    /// Returns true if the Layer had a DataLayout and it was NCHW or NCDHW.
    /// Returns false if the Layer either doesn't have a DataLayout or if it
    /// had a DataLayout that was neither NCHW nor NCDHW.
    bool Result()
    {
        return m_Result;
    }

private:
    template<typename Descriptor>
    void CheckDescForNCHW(const Descriptor& descriptor)
    {
        m_Result = (descriptor.m_DataLayout == DataLayout::NCHW) || (descriptor.m_DataLayout == DataLayout::NCDHW);
    }

    bool m_Result = false;
};
129
//
// This helper only works if none of the layers that these inputs are connected to are part of the selection.
//

SubgraphView::IInputSlots CreateIInputsFrom(const std::vector<armnn::IConnectableLayer*>& layers)
{
    SubgraphView::IInputSlots result;
    for (auto&& layer : layers)
    {
        for (unsigned int i = 0; i < layer->GetNumInputSlots(); ++i)
        {
            result.push_back(&(layer->GetInputSlot(i)));
        }
    }
    return result;
}

//
// This helper only works if none of the layers that these outputs are connected to are part of the selection.
//

SubgraphView::IOutputSlots CreateIOutputsFrom(const std::vector<armnn::IConnectableLayer*>& layers)
{
    SubgraphView::IOutputSlots result;
    for (auto&& layer : layers)
    {
        for (unsigned int i = 0; i < layer->GetNumOutputSlots(); ++i)
        {
            result.push_back(&(layer->GetOutputSlot(i)));
        }
    }
    return result;
}

} // anonymous namespace

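/// Returns true if the given Layer has a DataLayout of NCHW or NCDHW (see CheckForNCHW).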
inline bool IsNCHW(armnn::Layer& layer)
{
    CheckForNCHW check;
    layer.ExecuteStrategy(check);
    return check.Result();
}

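/// Reports each layer in the untouched map to the OptimizationViews as a single-layer
/// untouched subgraph, so that layers the backend did not optimize are still accounted for.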
inline void ReportUntouchedLayers(OptimizationViews& optimizationViews, std::map<LayerGuid, Layer*> untouched)
{
    for (const auto& pair : untouched)
    {
        Layer* layer = pair.second;
        SubgraphView subgraphView({layer},
                                  CreateIInputsFrom({layer}),
                                  CreateIOutputsFrom({layer}));
        optimizationViews.AddUntouchedSubgraph(std::move(subgraphView));
    }
}

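/// Substitutes the (padLayer, baseLayer) pair with replacementLayer: the substituted subgraph
/// takes its inputs from the Pad layer and its outputs from the base layer.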
template<typename LayerType>
LayerType* FoldPadLayer(OptimizationViews& optimizationViews,
                        LayerType* baseLayer,
                        LayerType* replacementLayer,
                        PadLayer* padLayer)
{
    SubgraphView substitutionSubgraph({padLayer, baseLayer},
                                      CreateIInputsFrom({padLayer}),
                                      CreateIOutputsFrom({baseLayer}));
    SubgraphView replacementSubgraph(replacementLayer);

    optimizationViews.AddSubstitution({substitutionSubgraph, replacementSubgraph});

    return replacementLayer;
}

/// Checks if the Layer is connected to any Layer that has an NCHW layout.
inline bool ConnectedToLayerWithNCHW(Layer* baseLayer)
{
    Layer& parentLayer = baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetOwningLayer();

    if (IsNCHW(parentLayer))
    {
        return true;
    }
    for (unsigned int i = 0; i < baseLayer->GetOutputSlot(0).GetNumConnections(); ++i)
    {
        Layer& nextLayer = baseLayer->GetOutputSlot(0).GetConnection(i)->GetOwningLayer();
        if (IsNCHW(nextLayer))
        {
            return true;
        }
    }
    return false;
}

/// Checks the Layer's connections to see if it's connected to a Layer with the provided layerType. If dimSize is
/// provided, it will also check whether the connecting tensor has more than that number of dimensions.
inline bool ConnectedToLayerType(Layer* baseLayer, LayerType layerType, unsigned int dimSize = 0)
{
    Layer& parentLayer = baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetOwningLayer();
    TensorInfo parentTensorInfo = baseLayer->GetInputSlot(0).GetTensorInfo();

    if (parentTensorInfo.GetNumDimensions() > dimSize && parentLayer.GetType() == layerType)
    {
        return true;
    }
    for (unsigned int i = 0; i < baseLayer->GetOutputSlot(0).GetNumConnections(); ++i)
    {
        Layer& nextLayer = baseLayer->GetOutputSlot(0).GetConnection(i)->GetOwningLayer();
        TensorInfo nextTensorInfo = baseLayer->GetOutputSlot(0).GetConnection(i)->GetTensorInfo();

        if (nextTensorInfo.GetNumDimensions() > dimSize && nextLayer.GetType() == layerType)
        {
            return true;
        }
    }
    return false;
}

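/// Attempts to remove a Reshape layer: if removal is safe, the Reshape's target shape is
/// written into the TensorInfo of every InputSlot fed by its OutputSlot, the layer is erased
/// from the untouched map, and the deletion is recorded in the OptimizationViews.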
inline void RemoveReshapeLayer(ReshapeLayer* baseLayer,
                               std::map<LayerGuid, Layer*>& untouched,
                               OptimizationViews& optimizationViews)
{
    if (baseLayer == nullptr)
    {
        return;
    }
    ReshapeDescriptor reshapeDescriptor = baseLayer->GetParameters();
    Layer& parentLayer = baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetOwningLayer();

    // Cannot currently remove the Reshape if its parent is an Input or Constant layer
    if (parentLayer.GetType() == LayerType::Input || parentLayer.GetType() == LayerType::Constant)
    {
        return;
    }

    // Cannot currently remove the Reshape if any of its outputs connect to an Output layer
    for (unsigned int i = 0; i < baseLayer->GetOutputSlot(0).GetNumConnections(); ++i)
    {
        Layer& nextLayer = baseLayer->GetOutputSlot(0).GetConnection(i)->GetOwningLayer();

        if (nextLayer.GetType() == LayerType::Output)
        {
            return;
        }
    }
    auto it = untouched.find(baseLayer->GetGuid());
    if (it == untouched.end())
    {
        // Already removed from map
        return;
    }
    untouched.erase(it);

    // Override the InputSlot TensorInfos for all the layers connected to the Reshape's OutputSlot
    for (unsigned int i = 0; i < baseLayer->GetOutputSlot(0).GetNumConnections(); ++i)
    {
        Layer& nextLayer = baseLayer->GetOutputSlot(0).GetConnection(i)->GetOwningLayer();
        auto inputIndex = baseLayer->GetOutputSlot(0).GetConnection(i)->GetSlotIndex();
        TensorInfo reshapeInfo(baseLayer->GetOutputSlot(0).GetTensorInfo());
        reshapeInfo.SetShape(reshapeDescriptor.m_TargetShape);
        nextLayer.GetInputSlot(inputIndex).SetTensorInfo(reshapeInfo);
    }
    optimizationViews.AddDeletedSubgraph(baseLayer);
}

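/// Creates a replacement Pooling2d layer from poolDescriptor (which is expected to already have
/// the Pad layer's padding folded into it) and substitutes it for the (Pad, Pooling2d) pair.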
template<typename LayerType>
LayerType* FoldPadIntoAveragePool2d(OptimizationViews& optimizationViews,
                                    Pooling2dLayer* baseLayer,
                                    Pooling2dDescriptor& poolDescriptor,
                                    PadLayer* padLayer)
{
    IConnectableLayer* replacement =
        optimizationViews.GetINetwork()->AddPooling2dLayer(poolDescriptor, "folded-pad-into-pool2d");
    LayerType* replacementLayer = PolymorphicDowncast<LayerType*>(replacement);

    FoldPadLayer(optimizationViews,
                 baseLayer,
                 replacementLayer,
                 padLayer);

    return replacementLayer;
}

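// A minimal sketch of how a backend's OptimizeSubgraphView() might use these helpers when it
// meets a Pad layer feeding a Pooling2d. Folding the padding values into poolDescriptor is
// assumed to have happened already (e.g. via the helpers in FoldPadIntoLayer2d.hpp); names
// such as 'base', 'optimizationViews' and 'untouched' are illustrative, not part of this header:
//
//     if (base.GetType() == LayerType::Pooling2d && ConnectedToLayerType(&base, LayerType::Pad))
//     {
//         PadLayer* padLayer = PolymorphicDowncast<PadLayer*>(
//             &base.GetInputSlot(0).GetConnectedOutputSlot()->GetOwningLayer());
//         Pooling2dLayer* poolLayer = PolymorphicDowncast<Pooling2dLayer*>(&base);
//         Pooling2dDescriptor poolDescriptor = poolLayer->GetParameters();
//
//         FoldPadIntoAveragePool2d<Pooling2dLayer>(optimizationViews, poolLayer,
//                                                  poolDescriptor, padLayer);
//         untouched.erase(padLayer->GetGuid());
//         untouched.erase(poolLayer->GetGuid());
//     }
//     ...
//     ReportUntouchedLayers(optimizationViews, untouched);
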
} // namespace armnn