blob: bd3d698a98213980b954f833934021c4614012e9 [file] [log] [blame]
//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
8#include <optimizations/FoldPadIntoLayer2d.hpp>
9
10namespace armnn
11{
12
13namespace
14{
15
16//
17// this helper only works if all layers where the inputs connect to are not selected
18//
19
20SubgraphView::IInputSlots CreateIInputsFrom(const std::vector<armnn::IConnectableLayer*>& layers)
21{
22 SubgraphView::IInputSlots result;
23 for (auto&& layer : layers)
24 {
25 for (unsigned int i = 0 ; i < layer->GetNumInputSlots(); ++i)
26 {
27 result.push_back(&(layer->GetInputSlot(i)));
28 }
29 }
30 return result;
31}
32
33//
34// this helper only works if all layers where the outputs connect to are not selected
35//
36
37SubgraphView::IOutputSlots CreateIOutputsFrom(const std::vector<armnn::IConnectableLayer*>& layers)
38{
39 SubgraphView::IOutputSlots result;
40 for (auto &&layer: layers)
41 {
42 for (unsigned int i = 0; i < layer->GetNumOutputSlots(); ++i)
43 {
44 result.push_back(&(layer->GetOutputSlot(i)));
45 }
46 }
47 return result;
48}
49
50}
51
52inline void ReportUntouchedLayers(OptimizationViews& optimizationViews, std::map<LayerGuid, Layer*> untouched)
53{
54 std::vector<Layer*> untouchedVector;
55 for (const auto& pair : untouched)
56 {
57 Layer* layer = pair.second;
58 SubgraphView subgraphView({layer},
59 CreateIInputsFrom({layer}),
60 CreateIOutputsFrom({layer}));
61 optimizationViews.AddUntouchedSubgraph(std::move(subgraphView));
62 }
63}
64
65template<typename LayerType>
66LayerType* FoldPadLayer(OptimizationViews& optimizationViews,
67 LayerType* baseLayer,
68 LayerType* replacementLayer,
69 PadLayer* padLayer)
70{
71 SubgraphView substitutionSubgraph({padLayer, baseLayer},
72 CreateIInputsFrom({padLayer}),
73 CreateIOutputsFrom({baseLayer}));
74 SubgraphView replacementSubgraph(replacementLayer);
75
76 optimizationViews.AddSubstitution({substitutionSubgraph, replacementSubgraph});
77
78 return replacementLayer;
79}
80
81template<typename LayerType>
82LayerType* FoldPadIntoAveragePool2d(OptimizationViews& optimizationViews,
83 Pooling2dLayer* baseLayer,
84 Pooling2dDescriptor& poolDescriptor,
85 PadLayer* padLayer)
86{
87 IConnectableLayer* replacement =
88 optimizationViews.GetINetwork()->AddPooling2dLayer(poolDescriptor, "folded-pad-into-pool2d");
89 LayerType* replacementLayer = PolymorphicDowncast<LayerType*>(replacement);
90
91 FoldPadLayer(optimizationViews,
92 baseLayer,
93 replacementLayer,
94 padLayer);
95
96 return replacementLayer;
97}
98
99} // namespace armnn