//
// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <aclCommon/ArmComputeSubgraphUtils.hpp>

namespace armnn
{

// Changes shapes of the form [1, 1, ..., W] to [ W ]
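// Returns false, leaving 'out' untouched, if any leading dimension is not 1.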
inline bool CollapseLeadingUnitDimensions(const TensorInfo& in, TensorInfo& out)
{
    unsigned int numDimensions = in.GetNumDimensions();
    for (unsigned int i = 0; i < (numDimensions - 1); ++i)
    {
        if (in.GetShape()[i] != 1)
        {
            return false;
        }
    }

    unsigned int w = in.GetShape()[numDimensions - 1];
    out = in;
    out.SetShape({w});

    return true;
}

//
// Build slot and tensor info lists for Add/Mul/Add replacement
//
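// Each entry of inputLayersSlotLists gives, for the corresponding layer in the
// Add/Mul/Add(/Relu) sequence, the input slots fed from outside the pattern;
// slots connected to the previous layer in the sequence are omitted. Similarly,
// outputLayersSlotLists gives the output slots visible to the rest of the graph
// (only the final layer, plus optionally the first Add, exposes one).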
template<typename SlotListType>
void BuildAddMulAddSlotLists(bool handleReLu,
                             bool multipleOutputs,
                             std::vector<SlotListType>& inputLayersSlotLists,
                             std::vector<SlotListType>& outputLayersSlotLists)
{
    // Build input slot list
    inputLayersSlotLists.push_back({0, 1});   // Add
    inputLayersSlotLists.push_back({1});      // Mul
    inputLayersSlotLists.push_back({1});      // Add
    if (handleReLu)
    {
        inputLayersSlotLists.push_back({});   // Relu
    }

    // Build output slot list
    if (multipleOutputs)
    {
        outputLayersSlotLists.push_back({0}); // Add
    }
    else
    {
        outputLayersSlotLists.push_back({});  // Add
    }
    outputLayersSlotLists.push_back({});      // Mul
    if (handleReLu)
    {
        outputLayersSlotLists.push_back({});  // Add
        outputLayersSlotLists.push_back({0}); // Relu
    }
    else
    {
        outputLayersSlotLists.push_back({0}); // Add
    }
}

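// Concatenates the names of up to four layers (stopping at the first null entry)
// into a single name for the fused layer, e.g. "fused-add0-mul0-add1".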
inline void GetFusedName(Layer* layerList[4], std::string& fusedName)
{
    // Build the fused name string
    fusedName = "fused";
    for (unsigned int layerIdx = 0; layerIdx < 4; ++layerIdx)
    {
        if (!layerList[layerIdx])
        {
            break;
        }
        fusedName += "-";
        fusedName += layerList[layerIdx]->GetNameStr();
    }
}

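// Validates that layerList holds an Add/Mul/Add sequence (optionally followed by
// a fusable activation) and collects the tensor infos of the connections entering
// and leaving the pattern. Returns false if the slot counts do not match the
// expected fused Add/Mul/Add shape.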
template<typename Type>
bool BuildAddMulAddTensorInfoLists(Type* layerList[4],
                                   unsigned int& numInputs,
                                   unsigned int& numOutputs,
                                   std::vector<TensorInfo>& inputInfos,
                                   std::vector<TensorInfo>& outputInfos,
                                   const ActivationDescriptor*& activationDescriptor,
                                   bool& fuseReLu)
{
    ARMNN_THROW_INVALIDARG_IF_FALSE(layerList[0]);
    ARMNN_THROW_INVALIDARG_IF_FALSE(layerList[1]);
    ARMNN_THROW_INVALIDARG_IF_FALSE(layerList[2]);

    ARMNN_THROW_INVALIDARG_IF_FALSE(IsSequenceLayerType(*layerList[0], BinaryOperation::Add));
    ARMNN_THROW_INVALIDARG_IF_FALSE(IsSequenceLayerType(*layerList[1], BinaryOperation::Mul));
    ARMNN_THROW_INVALIDARG_IF_FALSE(IsSequenceLayerType(*layerList[2], BinaryOperation::Add));

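    // A fourth layer is optional; when present it must be a ReLu or BoundedReLu
    // activation, which will be fused via its ActivationDescriptor.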
    fuseReLu = (layerList[3] != nullptr);
    if (fuseReLu)
    {
        activationDescriptor = &PolymorphicDowncast<ActivationLayer*>(layerList[3])->GetParameters();
        ARMNN_THROW_INVALIDARG_IF_FALSE((activationDescriptor->m_Function == ActivationFunction::ReLu) ||
                                        (activationDescriptor->m_Function == ActivationFunction::BoundedReLu));
    }

    numInputs = 0;
    numOutputs = 0;

    // Ensure that there are 6 input slots in the add/mul/add layers
    // we are going to replace
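    // (two on each of the three binary layers, of which two are connections
    // from the previous layer in the sequence, leaving four external inputs)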
    unsigned int layerIdx = 0;
    unsigned int inputSlotCount = 0;
    for (layerIdx = 0; layerIdx < 3; ++layerIdx)
    {
        for (unsigned int slotIdx = 0; slotIdx < layerList[layerIdx]->GetNumInputSlots(); ++slotIdx)
        {
            InputSlot* inputSlot = &layerList[layerIdx]->GetInputSlot(slotIdx);
            OutputSlot* outputSlot = inputSlot->GetConnectedOutputSlot();
            if (outputSlot)
            {
                if (layerIdx == 0)
                {
                    // Always count the input connections of the first add
                    inputInfos.push_back(inputSlot->GetTensorInfo());
                    numInputs++;
                }
                else
                {
                    // For subsequent layers, we skip connections to the previous layers in the counting
                    if (&outputSlot->GetOwningLayer() != layerList[layerIdx - 1])
                    {
                        TensorInfo inputSlotInfo = inputSlot->GetTensorInfo();
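                        // Inputs 2 and 3 are the second operands of the Mul and the
                        // final Add; these are the ones that may carry a broadcast
                        // constant whose shape needs collapsing.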
                        if (numInputs == 2 || numInputs == 3)
                        {
                            // Workaround the broadcast optimization to collapse shapes such as
                            // [1, 1, 1, 2] to [2] as required by the backend
                            if (CollapseLeadingUnitDimensions(inputSlot->GetTensorInfo(), inputSlotInfo))
                            {
                                OutputSlot* previousLayerSlot = inputSlot->GetConnectedOutputSlot();
                                if (previousLayerSlot)
                                {
                                    if (previousLayerSlot->GetOwningLayer().GetType() == LayerType::Constant)
                                    {
                                        // First update the TensorInfo in the constant owning layer
                                        previousLayerSlot->SetTensorInfo(inputSlotInfo);
                                        // Then update the TensorInfo in the workload for the owning layer
                                        ConstantLayer* layer = PolymorphicDowncast<ConstantLayer*>(
                                            &previousLayerSlot->GetOwningLayer());
                                        layer->m_LayerOutput = std::make_unique<ScopedTensorHandle>(
                                            ConstTensor(inputSlotInfo,
                                                        layer->m_LayerOutput.get()->GetConstTensor<void>()));
                                    }
                                }
                            }
                        }
                        inputInfos.push_back(inputSlotInfo);
                        numInputs++;
                    }
                }
                inputSlotCount++;
            }
        }
    }

    // Check the input counts
    bool validInputCount = (inputSlotCount == 6) && (inputInfos.size() == 4);
    if (!validInputCount)
    {
        return false;
    }

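    // Collect the outputs that leave the pattern: for intermediate layers, only
    // connections that bypass the next layer in the sequence count; every
    // connection from the final layer counts.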
    const unsigned int maxIdx = (fuseReLu) ? 4 : 3;
    for (layerIdx = 0; layerIdx < maxIdx; ++layerIdx)
    {
        for (unsigned int slotIdx = 0; slotIdx < layerList[layerIdx]->GetNumOutputSlots(); ++slotIdx)
        {
            OutputSlot* outputSlot = &layerList[layerIdx]->GetOutputSlot(slotIdx);

            for (unsigned int connectionIdx = 0; connectionIdx < outputSlot->GetNumConnections(); ++connectionIdx)
            {
                InputSlot* inputSlot = outputSlot->GetConnection(connectionIdx);
                if (layerIdx < (maxIdx - 1))
                {
                    if (&inputSlot->GetOwningLayer() != layerList[layerIdx + 1])
                    {
                        outputInfos.push_back(outputSlot->GetTensorInfo());
                        numOutputs++;
                    }
                }
                else if (layerList[layerIdx] != nullptr)
                {
                    outputInfos.push_back(outputSlot->GetTensorInfo());
                    numOutputs++;
                }
            }
        }
    }

    // Check the output count
    bool validOutputCount = (!outputInfos.empty());
    if (!validOutputCount)
    {
        return false;
    }

    return true;
}

} // namespace armnn