//
// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
| 5 | |
| 6 | #pragma once |
| 7 | |
| 8 | #include <aclCommon/ArmComputeSubgraphUtils.hpp> |
| 9 | |
| 10 | namespace armnn |
| 11 | { |
| 12 | |
| 13 | // Changes shapes of the form [1, 1, ..., W] to [ W ] |
| 14 | inline bool CollapseLeadingUnitDimensions(const TensorInfo& in, TensorInfo& out) |
| 15 | { |
| 16 | unsigned int numDimensions = in.GetNumDimensions(); |
| 17 | for (unsigned int i = 0; i < (numDimensions-1); ++i) |
| 18 | { |
| 19 | if (in.GetShape()[i] != 1) |
| 20 | { |
| 21 | return false; |
| 22 | } |
| 23 | } |
| 24 | |
| 25 | unsigned int w = in.GetShape()[numDimensions-1]; |
| 26 | out = in; |
| 27 | out.SetShape({w}); |
| 28 | |
| 29 | return true; |
| 30 | } |
| 31 | |
| 32 | // |
| 33 | // Build slot and tensor info lists for Add/Mul/Add replacement |
| 34 | // |
//
// Build slot and tensor info lists for Add/Mul/Add replacement
//
// Records, per layer in the Add->Mul->Add(->Relu) sequence, which input
// slots are fed from outside the pattern and which output slots must stay
// visible after fusion.
template<typename SlotListType>
void BuildAddMulAddSlotLists(bool handleReLu,
                             bool multipleOutputs,
                             std::vector<SlotListType>& inputLayersSlotLists,
                             std::vector<SlotListType>& outputLayersSlotLists)
{
    // External inputs: the first Add consumes two outside tensors,
    // the Mul and the second Add one each.
    inputLayersSlotLists.push_back(SlotListType{0, 1}); // Add
    inputLayersSlotLists.push_back(SlotListType{1});    // Mul
    inputLayersSlotLists.push_back(SlotListType{1});    // Add
    if (handleReLu)
    {
        // The activation is fed entirely from inside the pattern.
        inputLayersSlotLists.push_back(SlotListType{}); // Relu
    }

    // External outputs: the first Add exposes its result only when the
    // pattern has multiple consumers; the final layer always exposes slot 0.
    outputLayersSlotLists.push_back(multipleOutputs ? SlotListType{0}
                                                    : SlotListType{}); // Add
    outputLayersSlotLists.push_back(SlotListType{});                   // Mul
    if (handleReLu)
    {
        outputLayersSlotLists.push_back(SlotListType{});  // Add
        outputLayersSlotLists.push_back(SlotListType{0}); // Relu
    }
    else
    {
        outputLayersSlotLists.push_back(SlotListType{0}); // Add
    }
}
| 70 | |
| 71 | inline void GetFusedName(Layer *layerList[4], std::string& fusedName) |
| 72 | { |
| 73 | // Build the fused name string |
| 74 | fusedName = "fused"; |
| 75 | for (unsigned int layerIdx = 0; layerIdx< 4; ++layerIdx) |
| 76 | { |
| 77 | if (! layerList[layerIdx]) |
| 78 | { |
| 79 | break; |
| 80 | } |
| 81 | fusedName += "-"; |
| 82 | fusedName += layerList[layerIdx]->GetNameStr(); |
| 83 | } |
| 84 | } |
| 85 | |
// Validates an Add->Mul->Add (optionally followed by a ReLu/BoundedReLu
// activation) layer sequence and gathers the TensorInfos and connection
// counts needed to replace it with a single fused layer.
//
// layerList            - [in]  the three binary layers plus an optional
//                              activation, in execution order; layerList[3]
//                              may be null.
// numInputs/numOutputs - [out] counts of connections entering/leaving the
//                              pattern from outside it.
// inputInfos/outputInfos - [out] TensorInfos of those external connections.
// activationDescriptor - [out] set only when layerList[3] is present.
// fuseReLu             - [out] true when an activation layer is present.
//
// Returns false when the connection counts do not match the expected
// pattern shape; throws (via ARMNN_THROW_INVALIDARG_IF_FALSE) when the
// layers themselves are not of the expected types.
//
// NOTE(review): as a side effect this may rewrite the TensorInfo and
// workload tensor handle of Constant layers feeding the Mul/second Add
// (see the broadcast workaround below).
template<typename Type>
bool BuildAddMulAddTensorInfoLists(Type* layerList[4],
                                   unsigned int& numInputs,
                                   unsigned int& numOutputs,
                                   std::vector<TensorInfo>& inputInfos,
                                   std::vector<TensorInfo>& outputInfos,
                                   const ActivationDescriptor*& activationDescriptor,
                                   bool& fuseReLu)
{
    // The three binary layers are mandatory and must be Add, Mul, Add.
    ARMNN_THROW_INVALIDARG_IF_FALSE(layerList[0]);
    ARMNN_THROW_INVALIDARG_IF_FALSE(layerList[1]);
    ARMNN_THROW_INVALIDARG_IF_FALSE(layerList[2]);

    ARMNN_THROW_INVALIDARG_IF_FALSE(IsSequenceLayerType(*layerList[0], BinaryOperation::Add));
    ARMNN_THROW_INVALIDARG_IF_FALSE(IsSequenceLayerType(*layerList[1], BinaryOperation::Mul));
    ARMNN_THROW_INVALIDARG_IF_FALSE(IsSequenceLayerType(*layerList[2], BinaryOperation::Add));

    // A fourth layer, when present, must be a ReLu or BoundedReLu activation.
    fuseReLu = (layerList[3] != nullptr);
    if (fuseReLu)
    {
        activationDescriptor = &PolymorphicDowncast<ActivationLayer *>(layerList[3])->GetParameters();
        ARMNN_THROW_INVALIDARG_IF_FALSE((activationDescriptor->m_Function == ActivationFunction::ReLu) ||
                                        (activationDescriptor->m_Function == ActivationFunction::BoundedReLu));
    }

    numInputs = 0;
    numOutputs = 0;

    // Ensure that there are 6 input slots in the add/mul/add layers
    // we are going to replace
    unsigned int layerIdx = 0;
    unsigned int inputSlotCount = 0;
    for (layerIdx = 0; layerIdx < 3; ++layerIdx)
    {
        for (unsigned int slotIdx = 0; slotIdx < layerList[layerIdx]->GetNumInputSlots(); ++slotIdx)
        {
            InputSlot* inputSlot = &layerList[layerIdx]->GetInputSlot(slotIdx);
            OutputSlot* outputSlot = inputSlot->GetConnectedOutputSlot();
            // Only connected slots count towards the totals.
            if (outputSlot)
            {
                if (layerIdx == 0)
                {
                    // Always count the input connections of the first add
                    inputInfos.push_back(inputSlot->GetTensorInfo());
                    numInputs++;
                }
                else
                {
                    // For subsequent layers, we skip connections to the previous layers in the counting
                    if (&outputSlot->GetOwningLayer() != layerList[layerIdx-1])
                    {
                        TensorInfo inputSlotInfo = inputSlot->GetTensorInfo();
                        // numInputs 2 and 3 are the external inputs of the Mul
                        // and the second Add (the first Add contributed 0 and 1).
                        if (numInputs == 2 || numInputs == 3)
                        {
                            // Workaround the broadcast optimization to collapse shapes such as
                            // [1, 1, 1, 2] to [2] as required by backend
                            if (CollapseLeadingUnitDimensions(inputSlot->GetTensorInfo(), inputSlotInfo))
                            {
                                OutputSlot* previousLayerSlot = inputSlot->GetConnectedOutputSlot();
                                if (previousLayerSlot)
                                {
                                    // Only rewrite the producer when it is a Constant; other
                                    // producers keep their shape and just the recorded info changes.
                                    if (previousLayerSlot->GetOwningLayer().GetType() == LayerType::Constant)
                                    {
                                        // First update the TensorInfo in the constant owning layer
                                        previousLayerSlot->SetTensorInfo(inputSlotInfo);
                                        // Then update the TensorInfo in the workload for the owning layer
                                        ConstantLayer* layer = PolymorphicDowncast<ConstantLayer*>(
                                            &previousLayerSlot->GetOwningLayer());
                                        layer->m_LayerOutput
                                            = std::make_unique<ScopedTensorHandle>(
                                                ConstTensor(inputSlotInfo,
                                                            layer->m_LayerOutput.get()->GetConstTensor<void>()));
                                    }
                                }
                            }
                        }
                        inputInfos.push_back(inputSlotInfo);
                        numInputs++;
                    }
                }
                inputSlotCount++;
            }
        }
    }

    // Check the input counts
    // 6 connected slots in total, of which exactly 4 come from outside the
    // pattern (the other 2 are the Add->Mul and Mul->Add internal links).
    bool validInputCount = (inputSlotCount == 6) && (inputInfos.size() == 4);
    if (! validInputCount)
    {
        return false;
    }

    // Collect the externally visible outputs: any connection that does not
    // feed the next layer in the sequence, plus all of the last layer's
    // connections.
    const unsigned int maxIdx = (fuseReLu) ? 4 : 3;
    for (layerIdx = 0; layerIdx < maxIdx; ++layerIdx)
    {
        for (unsigned int slotIdx = 0; slotIdx < layerList[layerIdx]->GetNumOutputSlots(); ++slotIdx)
        {
            OutputSlot* outputSlot = &layerList[layerIdx]->GetOutputSlot(slotIdx);

            for (unsigned int connectionIdx = 0; connectionIdx < outputSlot->GetNumConnections(); ++connectionIdx)
            {
                InputSlot* inputSlot = outputSlot->GetConnection(connectionIdx);
                if (layerIdx < (maxIdx-1))
                {
                    // Intermediate layer: record only connections that leave the pattern.
                    if (&inputSlot->GetOwningLayer() != layerList[layerIdx+1])
                    {
                        outputInfos.push_back(outputSlot->GetTensorInfo());
                        numOutputs++;
                    }
                }
                else if (layerList[layerIdx] != nullptr)
                {
                    // Last layer of the sequence: every connection is an external output.
                    outputInfos.push_back(outputSlot->GetTensorInfo());
                    numOutputs++;
                }
            }
        }
    }

    // Check the output count
    bool validOutputCount = (outputInfos.size() > 0);
    if (! validOutputCount)
    {
        return false;
    }

    return true;
}
| 214 | |
| 215 | } |