//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "armnn/INetwork.hpp"
#include "armnnDeserializeParser/IDeserializeParser.hpp"
#include <Schema_generated.h>

#include <cstddef>
#include <cstdint>
#include <istream>
#include <string>
#include <vector>

12namespace armnnDeserializeParser
13{
14class DeserializeParser : public IDeserializeParser
15{
16public:
17 // Shorthands for deserializer types
18 using GraphPtr = const armnn::armnnSerializer::SerializedGraph *;
19 using TensorRawPtr = const armnn::armnnSerializer::TensorInfo *;
Saoirse Stewart3166c3e2019-02-18 15:24:53 +000020 using PoolingDescriptor = const armnn::armnnSerializer::Pooling2dDescriptor *;
Kevin May43a799c2019-02-08 16:31:42 +000021 using TensorRawPtrVector = std::vector<TensorRawPtr>;
22 using LayerRawPtr = const armnn::armnnSerializer::LayerBase *;
23 using LayerBaseRawPtr = const armnn::armnnSerializer::LayerBase *;
24 using LayerBaseRawPtrVector = std::vector<LayerBaseRawPtr>;
25
26public:
27
Derek Lamberti2b183fb2019-02-18 16:36:57 +000028 /// Create an input network from binary file contents
29 armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent) override;
Kevin May43a799c2019-02-08 16:31:42 +000030
Derek Lamberti2b183fb2019-02-18 16:36:57 +000031 /// Create an input network from a binary input stream
32 armnn::INetworkPtr CreateNetworkFromBinary(std::istream& binaryContent) override;
Kevin May43a799c2019-02-08 16:31:42 +000033
34 /// Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name
Derek Lamberti2b183fb2019-02-18 16:36:57 +000035 BindingPointInfo GetNetworkInputBindingInfo(unsigned int layerId, const std::string& name) const override;
Kevin May43a799c2019-02-08 16:31:42 +000036
37 /// Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name
Derek Lamberti2b183fb2019-02-18 16:36:57 +000038 BindingPointInfo GetNetworkOutputBindingInfo(unsigned int layerId, const std::string& name) const override;
Kevin May43a799c2019-02-08 16:31:42 +000039
40 DeserializeParser();
41 ~DeserializeParser() {}
42
43public:
44 // testable helpers
Kevin May43a799c2019-02-08 16:31:42 +000045 static GraphPtr LoadGraphFromBinary(const uint8_t* binaryContent, size_t len);
Derek Lamberti2b183fb2019-02-18 16:36:57 +000046 static GraphPtr LoadGraphFromBinary(std::istream& binaryContent);
Kevin May43a799c2019-02-08 16:31:42 +000047 static TensorRawPtrVector GetInputs(const GraphPtr& graph, unsigned int layerIndex);
48 static TensorRawPtrVector GetOutputs(const GraphPtr& graph, unsigned int layerIndex);
49 static LayerBaseRawPtrVector GetGraphInputs(const GraphPtr& graphPtr);
50 static LayerBaseRawPtrVector GetGraphOutputs(const GraphPtr& graphPtr);
51 static LayerBaseRawPtr GetBaseLayer(const GraphPtr& graphPtr, unsigned int layerIndex);
52 static int32_t GetBindingLayerInfo(const GraphPtr& graphPtr, unsigned int layerIndex);
Saoirse Stewart3166c3e2019-02-18 15:24:53 +000053 armnn::Pooling2dDescriptor GetPoolingDescriptor(PoolingDescriptor pooling2dDescriptor,
54 unsigned int layerIndex);
Saoirse Stewart263829c2019-02-19 15:54:14 +000055 static armnn::TensorInfo OutputShapeOfReshape(const armnn::TensorInfo & inputTensorInfo,
56 const std::vector<uint32_t> & targetDimsIn);
Kevin May43a799c2019-02-08 16:31:42 +000057
58private:
59 // No copying allowed until it is wanted and properly implemented
60 DeserializeParser(const DeserializeParser&) = delete;
61 DeserializeParser& operator=(const DeserializeParser&) = delete;
62
63 /// Create the network from an already loaded flatbuffers graph
64 armnn::INetworkPtr CreateNetworkFromGraph();
65
66 // signature for the parser functions
67 using LayerParsingFunction = void(DeserializeParser::*)(unsigned int layerIndex);
68
Aron Virginas-Tarfc413c02019-02-13 15:41:52 +000069 void ParseUnsupportedLayer(unsigned int layerIndex);
70 void ParseAdd(unsigned int layerIndex);
71 void ParseMultiplication(unsigned int layerIndex);
Saoirse Stewart3166c3e2019-02-18 15:24:53 +000072 void ParsePooling2d(unsigned int layerIndex);
Saoirse Stewart263829c2019-02-19 15:54:14 +000073 void ParseReshape(unsigned int layerIndex);
Aron Virginas-Tarfc413c02019-02-13 15:41:52 +000074 void ParseSoftmax(unsigned int layerIndex);
Kevin May43a799c2019-02-08 16:31:42 +000075
76 void RegisterOutputSlotOfConnection(uint32_t connectionIndex, armnn::IOutputSlot* slot);
77 void RegisterInputSlotOfConnection(uint32_t connectionIndex, armnn::IInputSlot* slot);
78 void RegisterInputSlots(uint32_t layerIndex,
79 armnn::IConnectableLayer* layer);
80 void RegisterOutputSlots(uint32_t layerIndex,
81 armnn::IConnectableLayer* layer);
82 void ResetParser();
83
84 void SetupInputLayers();
85 void SetupOutputLayers();
86
87 /// The network we're building. Gets cleared after it is passed to the user
88 armnn::INetworkPtr m_Network;
89 GraphPtr m_Graph;
90 std::vector<LayerParsingFunction> m_ParserFunctions;
Saoirse Stewart3166c3e2019-02-18 15:24:53 +000091 std::string m_layerName;
Kevin May43a799c2019-02-08 16:31:42 +000092
Kevin May43a799c2019-02-08 16:31:42 +000093 /// A mapping of an output slot to each of the input slots it should be connected to
94 /// The outputSlot is from the layer that creates this tensor as one of its outputs
95 /// The inputSlots are from the layers that use this tensor as one of their inputs
96 struct Slots
97 {
98 armnn::IOutputSlot* outputSlot;
99 std::vector<armnn::IInputSlot*> inputSlots;
100
101 Slots() : outputSlot(nullptr) { }
102 };
103 typedef std::vector<Slots> Connection;
104 std::vector<Connection> m_GraphConnections;
105};
106
107}