blob: a0fca4633077668bf840a6b2c6e634665fc975b0 [file] [log] [blame]
Mike Kelly07810fc2020-11-12 10:58:48 +00001//
2// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include <armnn/backends/OptimizationViews.hpp>
9
10namespace armnn
11{
12
13namespace
14{
15
16//
17// this helper only works if all layers where the inputs connect to are not selected
18//
19SubgraphView::InputSlots CreateInputsFrom(const std::vector<Layer*>& layers)
20{
21 SubgraphView::InputSlots result;
22 for (auto&& layer : layers)
23 {
24 for (auto&& it = layer->BeginInputSlots(); it != layer->EndInputSlots(); ++it)
25 {
26 result.push_back(&(*it));
27 }
28 }
29 return result;
30}
31
32//
33// this helper only works if all layers where the outputs connect to are not selected
34//
35SubgraphView::OutputSlots CreateOutputsFrom(const std::vector<Layer*>& layers)
36{
37 SubgraphView::OutputSlots result;
38 for (auto&& layer : layers)
39 {
40 for (auto&& it = layer->BeginOutputSlots(); it != layer->EndOutputSlots(); ++it)
41 {
42 result.push_back(&(*it));
43 }
44 }
45 return result;
46}
47
Teresa Charlind672f5d2021-01-18 18:07:57 +000048bool checkDataTypeInputandOutput(const Layer& layer)
49{
50 auto inputInfo = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
51 auto outputInfo = layer.GetOutputSlot(0).GetTensorInfo();
52 bool sameDataType = (inputInfo.GetDataType() == outputInfo.GetDataType());
53
54 // Check is same quantization info (same scale and offset)
55 if (sameDataType)
56 {
57 if (IsQuantizedType(inputInfo.GetDataType()))
58 {
59 bool sameScale = (inputInfo.GetQuantizationScale() == outputInfo.GetQuantizationScale());
60 bool sameOffset = (inputInfo.GetQuantizationOffset() == outputInfo.GetQuantizationOffset());
61
62 return (sameScale && sameOffset);
63 }
64 else
65 {
66 return true;
67 }
68 }
69 else
70 {
71 return false;
72 }
73}
74
Mike Kelly07810fc2020-11-12 10:58:48 +000075} // namespace
76
Mike Kelly1ac690a2020-11-17 11:41:38 +000077inline void ReportUntouchedLayers(OptimizationViews& optimizationViews, std::map<LayerGuid, Layer*> untouched)
Mike Kelly07810fc2020-11-12 10:58:48 +000078{
Mike Kelly1ac690a2020-11-17 11:41:38 +000079 std::vector<Layer*> untouchedVector;
80 for (const auto& pair : untouched)
Mike Kelly07810fc2020-11-12 10:58:48 +000081 {
Mike Kelly1ac690a2020-11-17 11:41:38 +000082 Layer* layer = pair.second;
83 SubgraphView subgraphView(CreateInputsFrom({layer}),
84 CreateOutputsFrom({layer}),
85 {layer});
86 optimizationViews.AddUntouchedSubgraph(std::move(subgraphView));
Mike Kelly07810fc2020-11-12 10:58:48 +000087 }
Mike Kelly07810fc2020-11-12 10:58:48 +000088}
89
90template<typename LayerType>
91LayerType* FuseLayerWithoutParameters(OptimizationViews& optimizationViews,
92 LayerType* baseLayer,
93 ActivationLayer* activationLayer,
94 ActivationDescriptor& activationDesc,
95 std::string name)
96{
97 LayerType* replacementLayer = optimizationViews.GetGraph().AddLayer<LayerType>(name.c_str());
98
99 replacementLayer->SetAdditionalInfoForObject(std::make_shared<ActivationDescriptor>(activationDesc));
100
101 SubgraphView substitutionSubgraph(CreateInputsFrom({baseLayer}),
102 CreateOutputsFrom({activationLayer}),
103 {baseLayer, activationLayer});
104 SubgraphView replacementSubgraph(replacementLayer);
105
106 optimizationViews.AddSubstitution({substitutionSubgraph, replacementSubgraph});
107 return replacementLayer;
108}
109
110template<typename LayerType>
111LayerType* FuseLayerWithParameters(OptimizationViews& optimizationViews,
112 LayerType* baseLayer,
113 ActivationLayer* activationLayer,
114 ActivationDescriptor& activationDesc,
115 std::string name)
116{
117 LayerType* replacementLayer = optimizationViews.GetGraph().AddLayer<LayerType>(baseLayer->GetParameters(),
118 name.c_str());
119
120 replacementLayer->SetAdditionalInfoForObject(std::make_shared<ActivationDescriptor>(activationDesc));
121
122 SubgraphView substitutionSubgraph(CreateInputsFrom({baseLayer}),
123 CreateOutputsFrom({activationLayer}),
124 {baseLayer, activationLayer});
125 SubgraphView replacementSubgraph(replacementLayer);
126
127 optimizationViews.AddSubstitution({substitutionSubgraph, replacementSubgraph});
128 return replacementLayer;
129}
130
131template<typename LayerType>
132LayerType* FuseLayerWithWeightsAndBiases(OptimizationViews& optimizationViews,
133 LayerType* baseLayer,
134 ActivationLayer* activationLayer,
135 ActivationDescriptor& activationDesc,
136 std::string name)
137{
138 LayerType* replacementLayer = FuseLayerWithParameters(optimizationViews,
139 baseLayer,
140 activationLayer,
141 activationDesc,
142 name);
143
144 replacementLayer->m_Weight = std::move(baseLayer->m_Weight);
145 replacementLayer->m_Bias = std::move(baseLayer->m_Bias);
146
147 return replacementLayer;
148}
149
150} // namespace armnn