surmeh01 | 3537c2c | 2018-05-18 16:31:43 +0100 | [diff] [blame] | 1 | // |
| 2 | // Copyright © 2017 Arm Ltd. All rights reserved. |
David Beck | ecb56cd | 2018-09-05 12:52:57 +0100 | [diff] [blame] | 3 | // SPDX-License-Identifier: MIT |
surmeh01 | 3537c2c | 2018-05-18 16:31:43 +0100 | [diff] [blame] | 4 | // |
#include "MergerLayer.hpp"
#include "LayerCloneBase.hpp"

#include <armnn/TypesUtils.hpp>
#include <backendsCommon/WorkloadData.hpp>
#include <backendsCommon/WorkloadFactory.hpp>

#include <algorithm>
#include <limits>
#include <queue>
| 13 | |
| 14 | namespace armnn |
| 15 | { |
| 16 | |
| 17 | MergerLayer::MergerLayer(const OriginsDescriptor& param, const char* name) |
| 18 | : LayerWithParameters(param.GetNumViews(), 1, LayerType::Merger, param, name) |
| 19 | { |
| 20 | } |
| 21 | |
| 22 | std::unique_ptr<IWorkload> MergerLayer::CreateWorkload(const Graph& graph, const IWorkloadFactory& factory) const |
| 23 | { |
| 24 | MergerQueueDescriptor descriptor; |
| 25 | |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 26 | // Copies the view origins to the descriptor. |
surmeh01 | 3537c2c | 2018-05-18 16:31:43 +0100 | [diff] [blame] | 27 | descriptor.m_ViewOrigins.reserve(m_Param.GetNumViews()); |
| 28 | for (unsigned int i = 0; i < m_Param.GetNumViews(); ++i) |
| 29 | { |
| 30 | descriptor.m_ViewOrigins.emplace_back( |
| 31 | std::vector<unsigned int>(m_Param.GetViewOrigin(i), m_Param.GetViewOrigin(i) + m_Param.GetNumDimensions())); |
| 32 | } |
| 33 | |
| 34 | return factory.CreateMerger(descriptor, PrepInfoAndDesc(descriptor, graph)); |
| 35 | } |
| 36 | |
| 37 | void MergerLayer::CreateTensorHandles(Graph& graph, const IWorkloadFactory& factory) |
| 38 | { |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 39 | //If sub tensors are supported than the merger |
surmeh01 | 3537c2c | 2018-05-18 16:31:43 +0100 | [diff] [blame] | 40 | //just needs to make sure that the outputs of the prev layer |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 41 | //are made subtensors of the output of the merger layer. |
surmeh01 | 3537c2c | 2018-05-18 16:31:43 +0100 | [diff] [blame] | 42 | m_OutputHandlers[0].CreateTensorHandles(factory); |
| 43 | if (factory.SupportsSubTensors()) |
| 44 | { |
| 45 | std::queue<MergerLayer*> m_MergerLayers; |
| 46 | |
| 47 | m_MergerLayers.push(this); |
| 48 | while (!m_MergerLayers.empty()) |
| 49 | { |
| 50 | MergerLayer* currentLayer = m_MergerLayers.front(); |
| 51 | ITensorHandle* parentTensor = currentLayer->GetOutputHandler(0).GetData(); |
| 52 | |
| 53 | m_MergerLayers.pop(); |
| 54 | |
| 55 | const unsigned int numInputSlots = currentLayer->GetNumInputSlots(); |
| 56 | for (unsigned int i = 0; i < numInputSlots; ++i) |
| 57 | { |
| 58 | OutputSlot* slot = currentLayer->GetInputSlot(i).GetConnectedOutputSlot(); |
| 59 | OutputHandler& outputHandler = slot->GetOutputHandler(); |
| 60 | outputHandler.SetData(factory.CreateSubTensorHandle(*parentTensor, |
| 61 | outputHandler.GetTensorInfo().GetShape(), |
| 62 | currentLayer->m_Param.GetViewOrigin(i))); |
| 63 | |
| 64 | Layer& inputLayer = slot->GetOwningLayer(); |
| 65 | if (inputLayer.GetType() == LayerType::Merger) |
| 66 | { |
| 67 | m_MergerLayers.push(boost::polymorphic_downcast<MergerLayer*>(&inputLayer)); |
| 68 | } |
| 69 | } |
| 70 | } |
| 71 | } |
| 72 | } |
| 73 | |
| 74 | MergerLayer* MergerLayer::Clone(Graph& graph) const |
| 75 | { |
| 76 | return CloneBase<MergerLayer>(graph, m_Param, GetName()); |
| 77 | } |
| 78 | |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 79 | std::vector<TensorShape> MergerLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const |
surmeh01 | 3537c2c | 2018-05-18 16:31:43 +0100 | [diff] [blame] | 80 | { |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 81 | BOOST_ASSERT(inputShapes.size() == m_Param.GetNumViews()); |
surmeh01 | 3537c2c | 2018-05-18 16:31:43 +0100 | [diff] [blame] | 82 | |
| 83 | unsigned int numDims = m_Param.GetNumDimensions(); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 84 | for (unsigned int i=0; i< inputShapes.size(); i++) |
surmeh01 | 3537c2c | 2018-05-18 16:31:43 +0100 | [diff] [blame] | 85 | { |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 86 | auto& inputShape = inputShapes[i]; |
surmeh01 | 3537c2c | 2018-05-18 16:31:43 +0100 | [diff] [blame] | 87 | |
surmeh01 | 3537c2c | 2018-05-18 16:31:43 +0100 | [diff] [blame] | 88 | ConditionalThrowIfNotEqual<LayerValidationException>( |
| 89 | "MergerLayer: Num Dimensions must match all inputs.", |
| 90 | numDims, |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 91 | inputShape.GetNumDimensions()); |
surmeh01 | 3537c2c | 2018-05-18 16:31:43 +0100 | [diff] [blame] | 92 | } |
| 93 | |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 94 | // Finds the bounding box (extents) of all the views. |
surmeh01 | 3537c2c | 2018-05-18 16:31:43 +0100 | [diff] [blame] | 95 | std::vector<unsigned int> extentMin(numDims); |
| 96 | std::vector<unsigned int> extentMax(numDims); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 97 | for (unsigned int i = 0; i < inputShapes.size(); i++) |
surmeh01 | 3537c2c | 2018-05-18 16:31:43 +0100 | [diff] [blame] | 98 | { |
| 99 | const uint32_t* origin = m_Param.GetViewOrigin(i); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 100 | const armnn::TensorShape& shape = inputShapes[i]; |
surmeh01 | 3537c2c | 2018-05-18 16:31:43 +0100 | [diff] [blame] | 101 | for (unsigned int d = 0; d < numDims; d++) |
| 102 | { |
| 103 | extentMin[d] = std::min(extentMin[d], origin[d]); |
| 104 | extentMax[d] = std::max(extentMax[d], origin[d] + shape[d]); |
| 105 | } |
| 106 | } |
| 107 | |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 108 | // Checks that the bounding box starts at the origin. |
surmeh01 | 3537c2c | 2018-05-18 16:31:43 +0100 | [diff] [blame] | 109 | if (!std::all_of(extentMin.begin(), extentMin.end(), [](unsigned int s) { return s == 0; })) |
| 110 | { |
| 111 | throw LayerValidationException("MergerLayer: there is no view that starts at the origin"); |
| 112 | } |
| 113 | |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 114 | // Checks that there are no overlaps of views (this would lead to undefined output at those locations). |
| 115 | // Checks each pair of views against each other |
| 116 | // (and doesn't bother to check against self, or check the same pair both ways round). |
| 117 | for (unsigned int a = 0; a < inputShapes.size(); a++) |
surmeh01 | 3537c2c | 2018-05-18 16:31:43 +0100 | [diff] [blame] | 118 | { |
| 119 | const uint32_t* aOrigin = m_Param.GetViewOrigin(a); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 120 | const armnn::TensorShape& aShape = inputShapes[a]; |
surmeh01 | 3537c2c | 2018-05-18 16:31:43 +0100 | [diff] [blame] | 121 | for (unsigned int b = 0; b < a; b++) |
| 122 | { |
| 123 | const uint32_t* bOrigin = m_Param.GetViewOrigin(b); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 124 | const armnn::TensorShape& bShape = inputShapes[b]; |
surmeh01 | 3537c2c | 2018-05-18 16:31:43 +0100 | [diff] [blame] | 125 | |
| 126 | bool allAxesOverlap = true; |
| 127 | for (unsigned int d = 0; d < numDims && allAxesOverlap; d++) |
| 128 | { |
| 129 | unsigned int a1 = aOrigin[d]; |
| 130 | unsigned int a2 = aOrigin[d] + aShape[d]; |
| 131 | |
| 132 | unsigned int b1 = bOrigin[d]; |
| 133 | unsigned int b2 = bOrigin[d] + bShape[d]; |
| 134 | |
| 135 | if (a2 <= b1 || b2 <= a1) |
| 136 | { |
| 137 | allAxesOverlap = false; |
| 138 | } |
| 139 | } |
| 140 | if (allAxesOverlap) |
| 141 | { |
| 142 | throw LayerValidationException("MergerLayer: Some views overlap."); |
| 143 | } |
| 144 | } |
| 145 | } |
| 146 | |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 147 | // Checks that there are no "holes", i.e. regions of the output which is not covered by a view. |
surmeh01 | 3537c2c | 2018-05-18 16:31:43 +0100 | [diff] [blame] | 148 | // Because we already checked that there are no overlaps, this can be done simply by checking that |
| 149 | // the total 'volume' of the views is the same as the output. |
| 150 | unsigned int totalViewsVolume = 0; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 151 | for (unsigned int i = 0; i < inputShapes.size(); i++) |
surmeh01 | 3537c2c | 2018-05-18 16:31:43 +0100 | [diff] [blame] | 152 | { |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 153 | totalViewsVolume += inputShapes[i].GetNumElements(); |
surmeh01 | 3537c2c | 2018-05-18 16:31:43 +0100 | [diff] [blame] | 154 | } |
| 155 | unsigned int outputVolume = 1; |
| 156 | for (unsigned int d = 0; d < numDims; d++) |
| 157 | { |
| 158 | outputVolume *= (extentMax[d] - extentMin[d]); |
| 159 | } |
| 160 | |
| 161 | ConditionalThrowIfNotEqual<LayerValidationException>( |
| 162 | "MergerLayer: there are some gaps between views", |
| 163 | totalViewsVolume, |
| 164 | outputVolume); |
| 165 | |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 166 | return std::vector<TensorShape>({ TensorShape({numDims, extentMax.data()}) }); |
| 167 | } |
| 168 | |
| 169 | void MergerLayer::ValidateTensorShapesFromInputs() |
| 170 | { |
| 171 | // Validates Merger layer. |
| 172 | ConditionalThrowIfNotEqual<LayerValidationException>( |
| 173 | "MergerLayer: Num Inputs must match num views.", |
| 174 | m_Param.GetNumViews(), |
| 175 | GetNumInputSlots()); |
| 176 | |
| 177 | VerifyLayerConnections(m_Param.GetNumViews(), CHECK_LOCATION()); |
| 178 | |
| 179 | std::vector<TensorShape> inputShapes; |
| 180 | for (uint i = 0; i < GetNumInputSlots(); ++i) |
| 181 | { |
| 182 | inputShapes.push_back(GetInputSlot(i).GetConnection()->GetTensorInfo().GetShape()); |
| 183 | } |
| 184 | |
| 185 | auto inferredShapes = InferOutputShapes(inputShapes); |
| 186 | |
| 187 | BOOST_ASSERT(inferredShapes.size() == 1); |
| 188 | |
surmeh01 | 3537c2c | 2018-05-18 16:31:43 +0100 | [diff] [blame] | 189 | ConditionalThrowIfNotEqual<LayerValidationException>( |
| 190 | "MergerLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.", |
| 191 | GetOutputSlot(0).GetTensorInfo().GetShape(), |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 192 | inferredShapes[0]); |
surmeh01 | 3537c2c | 2018-05-18 16:31:43 +0100 | [diff] [blame] | 193 | } |
| 194 | |
} // namespace armnn