blob: ddd02abede1a95e0d93d3f2b17629ddd743b5905 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "armnn/INetwork.hpp"
#include "armnnDeserializeParser/IDeserializeParser.hpp"
#include <Schema_generated.h>

#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

12namespace armnnDeserializeParser
13{
14class DeserializeParser : public IDeserializeParser
15{
16public:
17 // Shorthands for deserializer types
18 using GraphPtr = const armnn::armnnSerializer::SerializedGraph *;
19 using TensorRawPtr = const armnn::armnnSerializer::TensorInfo *;
20 using TensorRawPtrVector = std::vector<TensorRawPtr>;
21 using LayerRawPtr = const armnn::armnnSerializer::LayerBase *;
22 using LayerBaseRawPtr = const armnn::armnnSerializer::LayerBase *;
23 using LayerBaseRawPtrVector = std::vector<LayerBaseRawPtr>;
24
25public:
26
27 /// Create the network from a flatbuffers binary file on disk
28 virtual armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile) override;
29
30 virtual armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent) override;
31
32 /// Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name
33 virtual BindingPointInfo GetNetworkInputBindingInfo(unsigned int layerId,
34 const std::string& name) const override;
35
36 /// Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name
37 virtual BindingPointInfo GetNetworkOutputBindingInfo(unsigned int layerId,
38 const std::string& name) const override;
39
40 DeserializeParser();
41 ~DeserializeParser() {}
42
43public:
44 // testable helpers
Nattapat Chaimanowong43e78642019-02-13 15:56:24 +000045 static GraphPtr LoadGraphFromFile(const char* fileName, std::string& fileContent);
Kevin May43a799c2019-02-08 16:31:42 +000046 static GraphPtr LoadGraphFromBinary(const uint8_t* binaryContent, size_t len);
47 static TensorRawPtrVector GetInputs(const GraphPtr& graph, unsigned int layerIndex);
48 static TensorRawPtrVector GetOutputs(const GraphPtr& graph, unsigned int layerIndex);
49 static LayerBaseRawPtrVector GetGraphInputs(const GraphPtr& graphPtr);
50 static LayerBaseRawPtrVector GetGraphOutputs(const GraphPtr& graphPtr);
51 static LayerBaseRawPtr GetBaseLayer(const GraphPtr& graphPtr, unsigned int layerIndex);
52 static int32_t GetBindingLayerInfo(const GraphPtr& graphPtr, unsigned int layerIndex);
53
54private:
55 // No copying allowed until it is wanted and properly implemented
56 DeserializeParser(const DeserializeParser&) = delete;
57 DeserializeParser& operator=(const DeserializeParser&) = delete;
58
59 /// Create the network from an already loaded flatbuffers graph
60 armnn::INetworkPtr CreateNetworkFromGraph();
61
62 // signature for the parser functions
63 using LayerParsingFunction = void(DeserializeParser::*)(unsigned int layerIndex);
64
Aron Virginas-Tarfc413c02019-02-13 15:41:52 +000065 void ParseUnsupportedLayer(unsigned int layerIndex);
66 void ParseAdd(unsigned int layerIndex);
67 void ParseMultiplication(unsigned int layerIndex);
68 void ParseSoftmax(unsigned int layerIndex);
Kevin May43a799c2019-02-08 16:31:42 +000069
70 void RegisterOutputSlotOfConnection(uint32_t connectionIndex, armnn::IOutputSlot* slot);
71 void RegisterInputSlotOfConnection(uint32_t connectionIndex, armnn::IInputSlot* slot);
72 void RegisterInputSlots(uint32_t layerIndex,
73 armnn::IConnectableLayer* layer);
74 void RegisterOutputSlots(uint32_t layerIndex,
75 armnn::IConnectableLayer* layer);
76 void ResetParser();
77
78 void SetupInputLayers();
79 void SetupOutputLayers();
80
81 /// The network we're building. Gets cleared after it is passed to the user
82 armnn::INetworkPtr m_Network;
83 GraphPtr m_Graph;
84 std::vector<LayerParsingFunction> m_ParserFunctions;
85
Nattapat Chaimanowong43e78642019-02-13 15:56:24 +000086 /// This holds the data of the file that was read in from CreateNetworkFromBinaryFile
87 /// Needed for m_Graph to point to
88 std::string m_FileContent;
89
Kevin May43a799c2019-02-08 16:31:42 +000090 /// A mapping of an output slot to each of the input slots it should be connected to
91 /// The outputSlot is from the layer that creates this tensor as one of its outputs
92 /// The inputSlots are from the layers that use this tensor as one of their inputs
93 struct Slots
94 {
95 armnn::IOutputSlot* outputSlot;
96 std::vector<armnn::IInputSlot*> inputSlots;
97
98 Slots() : outputSlot(nullptr) { }
99 };
100 typedef std::vector<Slots> Connection;
101 std::vector<Connection> m_GraphConnections;
102};
103
104}