IVGCVSW-2783 Fix Deserializer connections for layers with multiple outputs
Change-Id: Icb278dfd8900334665432963fa6f6341a461ef3b
Signed-off-by: Nattapat Chaimanowong <nattapat.chaimanowong@arm.com>
diff --git a/src/armnnDeserializer/Deserializer.hpp b/src/armnnDeserializer/Deserializer.hpp
index 4d9c138..e837a08 100644
--- a/src/armnnDeserializer/Deserializer.hpp
+++ b/src/armnnDeserializer/Deserializer.hpp
@@ -9,6 +9,8 @@
#include "armnnDeserializer/IDeserializer.hpp"
#include <ArmnnSchema_generated.h>
+#include <unordered_map>
+
namespace armnnDeserializer
{
class Deserializer : public IDeserializer
@@ -100,8 +102,8 @@
void ParseStridedSlice(GraphPtr graph, unsigned int layerIndex);
void ParseSubtraction(GraphPtr graph, unsigned int layerIndex);
- void RegisterOutputSlotOfConnection(uint32_t connectionIndex, armnn::IOutputSlot* slot);
- void RegisterInputSlotOfConnection(uint32_t connectionIndex, armnn::IInputSlot* slot);
+ void RegisterOutputSlotOfConnection(uint32_t sourceLayerIndex, armnn::IOutputSlot* slot);
+ void RegisterInputSlotOfConnection(uint32_t sourceLayerIndex, uint32_t outputSlotIndex, armnn::IInputSlot* slot);
void RegisterInputSlots(GraphPtr graph, uint32_t layerIndex,
armnn::IConnectableLayer* layer);
void RegisterOutputSlots(GraphPtr graph, uint32_t layerIndex,
@@ -120,17 +122,14 @@
std::vector<NameToBindingInfo> m_OutputBindings;
/// A mapping of an output slot to each of the input slots it should be connected to
- /// The outputSlot is from the layer that creates this tensor as one of its outputs
- /// The inputSlots are from the layers that use this tensor as one of their inputs
- struct Slots
+ struct SlotsMap
{
- armnn::IOutputSlot* outputSlot;
- std::vector<armnn::IInputSlot*> inputSlots;
-
- Slots() : outputSlot(nullptr) { }
+ std::vector<armnn::IOutputSlot*> outputSlots;
+ std::unordered_map<unsigned int, std::vector<armnn::IInputSlot*>> inputSlots;
};
- typedef std::vector<Slots> Connection;
- std::vector<Connection> m_GraphConnections;
+
+ typedef std::vector<SlotsMap> Connections;
+ std::vector<Connections> m_GraphConnections;
};
} //namespace armnnDeserializer