blob: 860d88df8041e8fd5c3ad3c0ddf0f9b79fd8db2c [file] [log] [blame]
//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
8#include <armnn/backends/OptimizationViews.hpp>
9
10namespace armnn
11{
12
13namespace
14{
15
16//
17// this helper only works if all layers where the inputs connect to are not selected
18//
19SubgraphView::InputSlots CreateInputsFrom(const std::vector<Layer*>& layers)
20{
21 SubgraphView::InputSlots result;
22 for (auto&& layer : layers)
23 {
24 for (auto&& it = layer->BeginInputSlots(); it != layer->EndInputSlots(); ++it)
25 {
26 result.push_back(&(*it));
27 }
28 }
29 return result;
30}
31
32//
33// this helper only works if all layers where the outputs connect to are not selected
34//
35SubgraphView::OutputSlots CreateOutputsFrom(const std::vector<Layer*>& layers)
36{
37 SubgraphView::OutputSlots result;
38 for (auto&& layer : layers)
39 {
40 for (auto&& it = layer->BeginOutputSlots(); it != layer->EndOutputSlots(); ++it)
41 {
42 result.push_back(&(*it));
43 }
44 }
45 return result;
46}
47
48} // namespace
49
Mike Kelly1ac690a2020-11-17 11:41:38 +000050inline void ReportUntouchedLayers(OptimizationViews& optimizationViews, std::map<LayerGuid, Layer*> untouched)
Mike Kelly07810fc2020-11-12 10:58:48 +000051{
Mike Kelly1ac690a2020-11-17 11:41:38 +000052 std::vector<Layer*> untouchedVector;
53 for (const auto& pair : untouched)
Mike Kelly07810fc2020-11-12 10:58:48 +000054 {
Mike Kelly1ac690a2020-11-17 11:41:38 +000055 Layer* layer = pair.second;
56 SubgraphView subgraphView(CreateInputsFrom({layer}),
57 CreateOutputsFrom({layer}),
58 {layer});
59 optimizationViews.AddUntouchedSubgraph(std::move(subgraphView));
Mike Kelly07810fc2020-11-12 10:58:48 +000060 }
Mike Kelly07810fc2020-11-12 10:58:48 +000061}
62
63template<typename LayerType>
64LayerType* FuseLayerWithoutParameters(OptimizationViews& optimizationViews,
65 LayerType* baseLayer,
66 ActivationLayer* activationLayer,
67 ActivationDescriptor& activationDesc,
68 std::string name)
69{
70 LayerType* replacementLayer = optimizationViews.GetGraph().AddLayer<LayerType>(name.c_str());
71
72 replacementLayer->SetAdditionalInfoForObject(std::make_shared<ActivationDescriptor>(activationDesc));
73
74 SubgraphView substitutionSubgraph(CreateInputsFrom({baseLayer}),
75 CreateOutputsFrom({activationLayer}),
76 {baseLayer, activationLayer});
77 SubgraphView replacementSubgraph(replacementLayer);
78
79 optimizationViews.AddSubstitution({substitutionSubgraph, replacementSubgraph});
80 return replacementLayer;
81}
82
83template<typename LayerType>
84LayerType* FuseLayerWithParameters(OptimizationViews& optimizationViews,
85 LayerType* baseLayer,
86 ActivationLayer* activationLayer,
87 ActivationDescriptor& activationDesc,
88 std::string name)
89{
90 LayerType* replacementLayer = optimizationViews.GetGraph().AddLayer<LayerType>(baseLayer->GetParameters(),
91 name.c_str());
92
93 replacementLayer->SetAdditionalInfoForObject(std::make_shared<ActivationDescriptor>(activationDesc));
94
95 SubgraphView substitutionSubgraph(CreateInputsFrom({baseLayer}),
96 CreateOutputsFrom({activationLayer}),
97 {baseLayer, activationLayer});
98 SubgraphView replacementSubgraph(replacementLayer);
99
100 optimizationViews.AddSubstitution({substitutionSubgraph, replacementSubgraph});
101 return replacementLayer;
102}
103
104template<typename LayerType>
105LayerType* FuseLayerWithWeightsAndBiases(OptimizationViews& optimizationViews,
106 LayerType* baseLayer,
107 ActivationLayer* activationLayer,
108 ActivationDescriptor& activationDesc,
109 std::string name)
110{
111 LayerType* replacementLayer = FuseLayerWithParameters(optimizationViews,
112 baseLayer,
113 activationLayer,
114 activationDesc,
115 name);
116
117 replacementLayer->m_Weight = std::move(baseLayer->m_Weight);
118 replacementLayer->m_Bias = std::move(baseLayer->m_Bias);
119
120 return replacementLayer;
121}
122
123} // namespace armnn