blob: 1edb5a9f2391eb4ed307029442c4c2b602f3195f [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "armnn/INetwork.hpp"
#include "armnnDeserializeParser/IDeserializeParser.hpp"
#include <Schema_generated.h>

#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

12namespace armnnDeserializeParser
13{
14class DeserializeParser : public IDeserializeParser
15{
16public:
17 // Shorthands for deserializer types
18 using GraphPtr = const armnn::armnnSerializer::SerializedGraph *;
19 using TensorRawPtr = const armnn::armnnSerializer::TensorInfo *;
Saoirse Stewart3166c3e2019-02-18 15:24:53 +000020 using PoolingDescriptor = const armnn::armnnSerializer::Pooling2dDescriptor *;
Kevin May43a799c2019-02-08 16:31:42 +000021 using TensorRawPtrVector = std::vector<TensorRawPtr>;
22 using LayerRawPtr = const armnn::armnnSerializer::LayerBase *;
23 using LayerBaseRawPtr = const armnn::armnnSerializer::LayerBase *;
24 using LayerBaseRawPtrVector = std::vector<LayerBaseRawPtr>;
25
26public:
27
28 /// Create the network from a flatbuffers binary file on disk
29 virtual armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile) override;
30
31 virtual armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent) override;
32
33 /// Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name
34 virtual BindingPointInfo GetNetworkInputBindingInfo(unsigned int layerId,
35 const std::string& name) const override;
36
37 /// Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name
38 virtual BindingPointInfo GetNetworkOutputBindingInfo(unsigned int layerId,
39 const std::string& name) const override;
40
41 DeserializeParser();
42 ~DeserializeParser() {}
43
44public:
45 // testable helpers
Nattapat Chaimanowong43e78642019-02-13 15:56:24 +000046 static GraphPtr LoadGraphFromFile(const char* fileName, std::string& fileContent);
Kevin May43a799c2019-02-08 16:31:42 +000047 static GraphPtr LoadGraphFromBinary(const uint8_t* binaryContent, size_t len);
48 static TensorRawPtrVector GetInputs(const GraphPtr& graph, unsigned int layerIndex);
49 static TensorRawPtrVector GetOutputs(const GraphPtr& graph, unsigned int layerIndex);
50 static LayerBaseRawPtrVector GetGraphInputs(const GraphPtr& graphPtr);
51 static LayerBaseRawPtrVector GetGraphOutputs(const GraphPtr& graphPtr);
52 static LayerBaseRawPtr GetBaseLayer(const GraphPtr& graphPtr, unsigned int layerIndex);
53 static int32_t GetBindingLayerInfo(const GraphPtr& graphPtr, unsigned int layerIndex);
Saoirse Stewart3166c3e2019-02-18 15:24:53 +000054 armnn::Pooling2dDescriptor GetPoolingDescriptor(PoolingDescriptor pooling2dDescriptor,
55 unsigned int layerIndex);
Kevin May43a799c2019-02-08 16:31:42 +000056
57private:
58 // No copying allowed until it is wanted and properly implemented
59 DeserializeParser(const DeserializeParser&) = delete;
60 DeserializeParser& operator=(const DeserializeParser&) = delete;
61
62 /// Create the network from an already loaded flatbuffers graph
63 armnn::INetworkPtr CreateNetworkFromGraph();
64
65 // signature for the parser functions
66 using LayerParsingFunction = void(DeserializeParser::*)(unsigned int layerIndex);
67
Aron Virginas-Tarfc413c02019-02-13 15:41:52 +000068 void ParseUnsupportedLayer(unsigned int layerIndex);
69 void ParseAdd(unsigned int layerIndex);
70 void ParseMultiplication(unsigned int layerIndex);
Saoirse Stewart3166c3e2019-02-18 15:24:53 +000071 void ParsePooling2d(unsigned int layerIndex);
Aron Virginas-Tarfc413c02019-02-13 15:41:52 +000072 void ParseSoftmax(unsigned int layerIndex);
Kevin May43a799c2019-02-08 16:31:42 +000073
74 void RegisterOutputSlotOfConnection(uint32_t connectionIndex, armnn::IOutputSlot* slot);
75 void RegisterInputSlotOfConnection(uint32_t connectionIndex, armnn::IInputSlot* slot);
76 void RegisterInputSlots(uint32_t layerIndex,
77 armnn::IConnectableLayer* layer);
78 void RegisterOutputSlots(uint32_t layerIndex,
79 armnn::IConnectableLayer* layer);
80 void ResetParser();
81
82 void SetupInputLayers();
83 void SetupOutputLayers();
84
85 /// The network we're building. Gets cleared after it is passed to the user
86 armnn::INetworkPtr m_Network;
87 GraphPtr m_Graph;
88 std::vector<LayerParsingFunction> m_ParserFunctions;
Saoirse Stewart3166c3e2019-02-18 15:24:53 +000089 std::string m_layerName;
Kevin May43a799c2019-02-08 16:31:42 +000090
Nattapat Chaimanowong43e78642019-02-13 15:56:24 +000091 /// This holds the data of the file that was read in from CreateNetworkFromBinaryFile
92 /// Needed for m_Graph to point to
93 std::string m_FileContent;
94
Kevin May43a799c2019-02-08 16:31:42 +000095 /// A mapping of an output slot to each of the input slots it should be connected to
96 /// The outputSlot is from the layer that creates this tensor as one of its outputs
97 /// The inputSlots are from the layers that use this tensor as one of their inputs
98 struct Slots
99 {
100 armnn::IOutputSlot* outputSlot;
101 std::vector<armnn::IInputSlot*> inputSlots;
102
103 Slots() : outputSlot(nullptr) { }
104 };
105 typedef std::vector<Slots> Connection;
106 std::vector<Connection> m_GraphConnections;
107};
108
109}