//
// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <armnn/StrategyBase.hpp>
#include <armnn/Descriptors.hpp>
#include <optimizations/FoldPadIntoLayer2d.hpp>

namespace armnn
{

namespace
{

/// Checks if a Layer has a DataLayout that is either NCHW or NCDHW.
class CheckForNCHW : public StrategyBase<NoThrowStrategy>
{
public:
    CheckForNCHW()
    {}

    void ExecuteStrategy(const armnn::IConnectableLayer* layer,
                         const armnn::BaseDescriptor& descriptor,
                         const std::vector<armnn::ConstTensor>& constants,
                         const char* name,
                         const armnn::LayerBindingId id = 0) override
    {
        armnn::IgnoreUnused(constants, id, name);
        switch (layer->GetType())
        {
            case armnn::LayerType::BatchMatMul:
            {
                const auto& desc = static_cast<const armnn::BatchMatMulDescriptor&>(descriptor);
                m_Result = desc.m_DataLayoutX == DataLayout::NCHW || desc.m_DataLayoutY == DataLayout::NCHW;
                break;
            }
            case armnn::LayerType::BatchNormalization:
            {
                CheckDescForNCHW(static_cast<const armnn::BatchNormalizationDescriptor&>(descriptor));
                break;
            }
            case armnn::LayerType::BatchToSpaceNd:
            {
                CheckDescForNCHW(static_cast<const armnn::BatchToSpaceNdDescriptor&>(descriptor));
                break;
            }
            case armnn::LayerType::Convolution2d:
            {
                CheckDescForNCHW(static_cast<const armnn::Convolution2dDescriptor&>(descriptor));
                break;
            }
            case armnn::LayerType::Convolution3d:
            {
                CheckDescForNCHW(static_cast<const armnn::Convolution3dDescriptor&>(descriptor));
                break;
            }
            case armnn::LayerType::DepthwiseConvolution2d:
            {
                CheckDescForNCHW(static_cast<const armnn::DepthwiseConvolution2dDescriptor&>(descriptor));
                break;
            }
            case armnn::LayerType::InstanceNormalization:
            {
                CheckDescForNCHW(static_cast<const armnn::InstanceNormalizationDescriptor&>(descriptor));
                break;
            }
            case armnn::LayerType::L2Normalization:
            {
                CheckDescForNCHW(static_cast<const armnn::L2NormalizationDescriptor&>(descriptor));
                break;
            }
            case armnn::LayerType::Normalization:
            {
                CheckDescForNCHW(static_cast<const armnn::NormalizationDescriptor&>(descriptor));
                break;
            }
            case armnn::LayerType::Pooling2d:
            {
                CheckDescForNCHW(static_cast<const armnn::Pooling2dDescriptor&>(descriptor));
                break;
            }
            case armnn::LayerType::Pooling3d:
            {
                CheckDescForNCHW(static_cast<const armnn::Pooling3dDescriptor&>(descriptor));
                break;
            }
            case armnn::LayerType::SpaceToBatchNd:
            {
                CheckDescForNCHW(static_cast<const armnn::SpaceToBatchNdDescriptor&>(descriptor));
                break;
            }
            case armnn::LayerType::SpaceToDepth:
            {
                CheckDescForNCHW(static_cast<const armnn::SpaceToDepthDescriptor&>(descriptor));
                break;
            }
            case armnn::LayerType::StridedSlice:
            {
                CheckDescForNCHW(static_cast<const armnn::StridedSliceDescriptor&>(descriptor));
                break;
            }
            default:
            {
                m_Result = false;
            }
        }
    }

    /// Returns true if the Layer had a DataLayout and it was NCHW or NCDHW.
    /// Returns false if the Layer either doesn't have a DataLayout or if it
    /// had a DataLayout that was neither NCHW nor NCDHW.
    bool Result()
    {
        return m_Result;
    }

private:
    template<typename Descriptor>
    void CheckDescForNCHW(const Descriptor& descriptor)
    {
        m_Result = (descriptor.m_DataLayout == DataLayout::NCHW) || (descriptor.m_DataLayout == DataLayout::NCDHW);
    }

    bool m_Result = false;
};
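
// A CheckForNCHW instance is driven through Layer::ExecuteStrategy; see the
// IsNCHW helper below for the typical single-layer query.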

//
// This helper only works if none of the layers that the inputs connect to are part of the selection.
//

SubgraphView::IInputSlots CreateIInputsFrom(const std::vector<armnn::IConnectableLayer*>& layers)
{
    SubgraphView::IInputSlots result;
    for (auto&& layer : layers)
    {
        for (unsigned int i = 0; i < layer->GetNumInputSlots(); ++i)
        {
            result.push_back(&(layer->GetInputSlot(i)));
        }
    }
    return result;
}

//
// This helper only works if none of the layers that the outputs connect to are part of the selection.
//

SubgraphView::IOutputSlots CreateIOutputsFrom(const std::vector<armnn::IConnectableLayer*>& layers)
{
    SubgraphView::IOutputSlots result;
    for (auto&& layer : layers)
    {
        for (unsigned int i = 0; i < layer->GetNumOutputSlots(); ++i)
        {
            result.push_back(&(layer->GetOutputSlot(i)));
        }
    }
    return result;
}

// Type used to hold the slot numbers to create the lists from. There should
// be a SlotList for each layer in the layers list.
typedef std::vector<int> SlotList;
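//
// For example, to select input slot 0 of the first layer and input slots 0 and
// 1 of the second, pass: std::vector<SlotList> layersSlotLists = { {0}, {0, 1} };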

template<typename ILayerType>
SubgraphView::IInputSlots CreateIInputsFromSlotLists(const std::vector<ILayerType*>& layers,
                                                     const std::vector<SlotList>& layersSlotLists)
{
    ARMNN_THROW_INVALIDARG_IF_FALSE(layersSlotLists.size() == layers.size());

    SubgraphView::IInputSlots result;

    for (unsigned int layerIdx = 0; layerIdx < layers.size(); ++layerIdx)
    {
        const SlotList& slotList = layersSlotLists[layerIdx];
        for (unsigned int slotIdx = 0; slotIdx < layers[layerIdx]->GetNumInputSlots(); ++slotIdx)
        {
            if (std::find(slotList.begin(), slotList.end(), slotIdx) != slotList.end())
            {
                result.push_back(&(layers[layerIdx]->GetInputSlot(slotIdx)));
            }
        }
    }
    return result;
}

template<typename ILayerType>
SubgraphView::IOutputSlots CreateIOutputsFromSlotLists(const std::vector<ILayerType*>& layers,
                                                       const std::vector<SlotList>& layersSlotLists)
{
    ARMNN_THROW_INVALIDARG_IF_FALSE(layersSlotLists.size() == layers.size());

    SubgraphView::IOutputSlots result;
    for (unsigned int layerIdx = 0; layerIdx < layers.size(); ++layerIdx)
    {
        const SlotList& slotList = layersSlotLists[layerIdx];
        for (unsigned int slotIdx = 0; slotIdx < layers[layerIdx]->GetNumOutputSlots(); ++slotIdx)
        {
            bool foundIt = std::find(slotList.begin(), slotList.end(), slotIdx) != slotList.end();
            if (foundIt)
            {
                result.push_back(&(layers[layerIdx]->GetOutputSlot(slotIdx)));
            }
        }
    }
    return result;
}
} // anonymous namespace

inline bool IsNCHW(armnn::Layer& layer)
{
    CheckForNCHW check;
    layer.ExecuteStrategy(check);
    return check.Result();
}

inline void ReportUntouchedLayers(OptimizationViews& optimizationViews, std::map<LayerGuid, Layer*> untouched)
{
    for (const auto& pair : untouched)
    {
        Layer* layer = pair.second;
        SubgraphView subgraphView({layer},
                                  CreateIInputsFrom({layer}),
                                  CreateIOutputsFrom({layer}));
        optimizationViews.AddUntouchedSubgraph(std::move(subgraphView));
    }
}

template<typename LayerType>
LayerType* FoldPadLayer(OptimizationViews& optimizationViews,
                        LayerType* baseLayer,
                        LayerType* replacementLayer,
                        PadLayer* padLayer)
{
    SubgraphView substitutionSubgraph({padLayer, baseLayer},
                                      CreateIInputsFrom({padLayer}),
                                      CreateIOutputsFrom({baseLayer}));
    SubgraphView replacementSubgraph(replacementLayer);

    optimizationViews.AddSubstitution({substitutionSubgraph, replacementSubgraph});

    return replacementLayer;
}
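
// Note: the substitution replaces the {padLayer, baseLayer} pair with the
// single replacement layer; the Pad's inputs and the base layer's outputs
// form the boundary of the substituted subgraph.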

/// Checks if the Layer is connected to any Layer that has an NCHW or NCDHW layout.
/// Only the producer of the Layer's first input and the consumers of its first output are inspected.
inline bool ConnectedToLayerWithNCHW(Layer* baseLayer)
{
    Layer& parentLayer = baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetOwningLayer();

    if (IsNCHW(parentLayer))
    {
        return true;
    }
    for (unsigned int i = 0; i < baseLayer->GetOutputSlot(0).GetNumConnections(); ++i)
    {
        Layer& nextLayer = baseLayer->GetOutputSlot(0).GetConnection(i)->GetOwningLayer();
        if (IsNCHW(nextLayer))
        {
            return true;
        }
    }
    return false;
}

/// Checks the Layer's connections to see if it's connected to a Layer with the provided layerType. If dimSize is
/// provided, it will also check whether the connecting tensor has more than that number of dimensions.
inline bool ConnectedToLayerType(Layer* baseLayer, LayerType layerType, unsigned int dimSize = 0)
{
    Layer& parentLayer = baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetOwningLayer();
    TensorInfo parentTensorInfo = baseLayer->GetInputSlot(0).GetTensorInfo();

    if (parentTensorInfo.GetNumDimensions() > dimSize && parentLayer.GetType() == layerType)
    {
        return true;
    }
    for (unsigned int i = 0; i < baseLayer->GetOutputSlot(0).GetNumConnections(); ++i)
    {
        Layer& nextLayer = baseLayer->GetOutputSlot(0).GetConnection(i)->GetOwningLayer();
        TensorInfo nextTensorInfo = baseLayer->GetOutputSlot(0).GetConnection(i)->GetTensorInfo();

        if (nextTensorInfo.GetNumDimensions() > dimSize && nextLayer.GetType() == layerType)
        {
            return true;
        }
    }
    return false;
}
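
// For example, ConnectedToLayerType(layer, LayerType::Reshape, 2) returns true
// only if a neighbouring Reshape is reached through a tensor with more than
// two dimensions.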

inline void RemoveReshapeLayer(ReshapeLayer* baseLayer,
                               std::map<LayerGuid, Layer*>& untouched,
                               OptimizationViews& optimizationViews)
{
    if (baseLayer == nullptr)
    {
        return;
    }
    ReshapeDescriptor reshapeDescriptor = baseLayer->GetParameters();
    Layer& parentLayer = baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetOwningLayer();

    // Cannot currently remove the Reshape if its input comes from an Input or Constant layer
    if (parentLayer.GetType() == LayerType::Input || parentLayer.GetType() == LayerType::Constant)
    {
        return;
    }

    // Cannot currently remove the Reshape if any of its outputs connect to an Output layer
    for (unsigned int i = 0; i < baseLayer->GetOutputSlot(0).GetNumConnections(); ++i)
    {
        Layer& nextLayer = baseLayer->GetOutputSlot(0).GetConnection(i)->GetOwningLayer();

        if (nextLayer.GetType() == LayerType::Output)
        {
            return;
        }
    }
    auto it = untouched.find(baseLayer->GetGuid());
    if (it == untouched.end())
    {
        // Not in the map: the Reshape has already been removed
        return;
    }
    untouched.erase(it);

    // Override the InputSlot TensorInfos for all the layers connected to the Reshape's OutputSlot
    for (unsigned int i = 0; i < baseLayer->GetOutputSlot(0).GetNumConnections(); ++i)
    {
        Layer& nextLayer = baseLayer->GetOutputSlot(0).GetConnection(i)->GetOwningLayer();
        auto inputIndex = baseLayer->GetOutputSlot(0).GetConnection(i)->GetSlotIndex();
        TensorInfo reshapeInfo(baseLayer->GetOutputSlot(0).GetTensorInfo());
        reshapeInfo.SetShape(reshapeDescriptor.m_TargetShape);
        nextLayer.GetInputSlot(inputIndex).SetTensorInfo(reshapeInfo);
    }
    optimizationViews.AddDeletedSubgraph(baseLayer);
}
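
// For example, if the Reshape's target shape is {1, 4}, every consumer of the
// Reshape is rewired to read a {1, 4} tensor directly from the Reshape's
// parent once the Reshape has been marked as a deleted subgraph.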

template<typename LayerType>
LayerType* FoldPadIntoAveragePool2d(OptimizationViews& optimizationViews,
                                    Pooling2dLayer* baseLayer,
                                    Pooling2dDescriptor& poolDescriptor,
                                    PadLayer* padLayer)
{
    IConnectableLayer* replacement =
        optimizationViews.GetINetwork()->AddPooling2dLayer(poolDescriptor, "folded-pad-into-pool2d");
    LayerType* replacementLayer = PolymorphicDowncast<LayerType*>(replacement);

    FoldPadLayer(optimizationViews,
                 baseLayer,
                 replacementLayer,
                 padLayer);

    return replacementLayer;
}
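
// A minimal calling sketch, assuming a backend's OptimizeSubgraphView() has
// already validated the fold and merged the Pad's padding into the pooling
// descriptor (e.g. via the FoldPadIntoLayer2d optimization):
//
//     Pooling2dDescriptor poolDescriptor = baseLayer->GetParameters();
//     // ...merge padLayer->GetParameters() into poolDescriptor's padding...
//     FoldPadIntoAveragePool2d<Pooling2dLayer>(optimizationViews, baseLayer,
//                                              poolDescriptor, padLayer);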

//
// Layer sequence detection, e.g. Add + Mul + Add (+ optional activation)
//

inline bool IsSequenceLayerType(Layer& layer, LayerType type)
{
    return layer.GetType() == type;
}

inline bool IsSequenceLayerType(Layer& layer, BinaryOperation type)
{
    return (layer.GetType() == LayerType::ElementwiseBinary) &&
           (PolymorphicDowncast<ElementwiseBinaryLayer*>(&layer)->GetParameters().m_Operation == type);
}
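
// For example, IsSequenceLayerType(layer, BinaryOperation::Mul) matches an
// ElementwiseBinary layer configured to perform multiplication.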

// Detect a layer sequence and activation if specified. The activation must be at the end of the sequence.
template<typename TYPE>
bool IsLayerSequence(Layer& currentLayer,
                     TYPE first,
                     TYPE second,
                     TYPE third,
                     Layer* layerList[4],
                     bool handleValidActivates,
                     const std::vector<ActivationFunction>& validActivates)
{
    auto PreviousLayer = [](Layer& layer)
    {
        return &layer.GetInputSlot(0).GetConnectedOutputSlot()->GetOwningLayer();
    };

    auto NextLayer = [](Layer& layer)
    {
        return &layer.GetOutputSlot(0).GetConnection(0)->GetOwningLayer();
    };

    auto LayerIncomingConnectionDataType = [](Layer& layer)
    {
        return layer.GetInputSlot(0).GetTensorInfo().GetDataType();
    };

    bool result = false;

    // Match in reverse so there is only 1 connection to check
    if (IsSequenceLayerType(currentLayer, third))
    {
        // Save DataType of third layer
        DataType dataType = LayerIncomingConnectionDataType(currentLayer);

        // Save third layer
        layerList[2] = &currentLayer;

        // Check the layers that precede this one for the requested grouping
        Layer* prevLayer = PreviousLayer(currentLayer);
        if (prevLayer && IsSequenceLayerType(*prevLayer, second))
        {
            bool dataTypesMatch = (dataType == LayerIncomingConnectionDataType(*prevLayer));
            if (!dataTypesMatch)
            {
                return result;
            }

            layerList[1] = prevLayer;
            prevLayer = PreviousLayer(*prevLayer);
            if (prevLayer && IsSequenceLayerType(*prevLayer, first))
            {
                dataTypesMatch = (dataType == LayerIncomingConnectionDataType(*prevLayer));
                if (!dataTypesMatch)
                {
                    return result;
                }

                layerList[0] = prevLayer;

                // Detected the first 3 layers if we get to this point, so now
                // check to see if we have a valid activation. If there is no
                // activation then the sequence still matches.
                if (handleValidActivates)
                {
                    Layer* nextLayer = NextLayer(currentLayer);
                    if (nextLayer)
                    {
                        if (IsSequenceLayerType(*nextLayer, LayerType::Activation))
                        {
                            // This layer is an activation, so it must be a valid type for the sequence
                            ActivationFunction activationFunction =
                                PolymorphicDowncast<ActivationLayer*>(nextLayer)->GetParameters().m_Function;
                            long count = std::count(validActivates.cbegin(),
                                                    validActivates.cend(),
                                                    activationFunction);
                            if (count > 0)
                            {
                                layerList[3] = nextLayer;
                                result = true;
                            }
                        }
                        else
                        {
                            // Next layer is not an activation so the sequence still matches
                            result = true;
                        }
                    }
                }
                else
                {
                    result = true;
                }
            }
        }
    }

    return result;
}
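
// A usage sketch (layer names are illustrative): starting from the last Add of
// a candidate Add -> Mul -> Add (+ optional activation) chain:
//
//     Layer* layerList[4] = { nullptr, nullptr, nullptr, nullptr };
//     bool matched = IsLayerSequence(lastAdd,
//                                    BinaryOperation::Add,
//                                    BinaryOperation::Mul,
//                                    BinaryOperation::Add,
//                                    layerList,
//                                    true,
//                                    { ActivationFunction::ReLu });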

} // namespace armnn